4
4
#![ cfg( feature = "python-extension" ) ]
5
5
6
6
use crate :: algebraic_numbers:: RealAlgebraicNumber ;
7
+ use crate :: traits:: ExactDivAssign ;
7
8
use num_bigint:: BigInt ;
8
9
use num_bigint:: Sign ;
10
+ use num_traits:: Signed ;
9
11
use num_traits:: ToPrimitive ;
10
12
use num_traits:: Zero ;
13
+ use pyo3:: basic:: CompareOp ;
14
+ use pyo3:: exceptions:: TypeError ;
15
+ use pyo3:: exceptions:: ValueError ;
16
+ use pyo3:: exceptions:: ZeroDivisionError ;
11
17
use pyo3:: prelude:: * ;
12
18
use pyo3:: types:: IntoPyDict ;
13
19
use pyo3:: types:: PyAny ;
14
20
use pyo3:: types:: PyBytes ;
15
21
use pyo3:: types:: PyInt ;
16
22
use pyo3:: types:: PyType ;
17
23
use pyo3:: PyNativeType ;
24
+ use pyo3:: PyNumberProtocol ;
18
25
use pyo3:: PyObjectProtocol ;
26
+ use std:: sync:: Arc ;
19
27
20
28
// TODO: Switch to using BigInt's python conversions once they are implemented
21
29
// see https://github.com/PyO3/pyo3/issues/543
@@ -66,19 +74,31 @@ impl FromPyObject<'_> for PyBigInt {
66
74
}
67
75
68
76
#[ pyclass( name=RealAlgebraicNumber , module="algebraics" ) ]
77
+ #[ derive( Clone ) ]
69
78
struct RealAlgebraicNumberPy {
70
- value : RealAlgebraicNumber ,
79
+ value : Arc < RealAlgebraicNumber > ,
71
80
}
72
81
73
- #[ pymethods( PyObjectProtocol ) ]
74
- impl RealAlgebraicNumberPy {
75
- #[ new]
76
- fn pynew ( obj : & PyRawObject , value : Option < & PyInt > ) -> PyResult < ( ) > {
82
+ impl FromPyObject < ' _ > for RealAlgebraicNumberPy {
83
+ fn extract ( value : & PyAny ) -> PyResult < Self > {
84
+ if let Ok ( value) = value. downcast_ref :: < RealAlgebraicNumberPy > ( ) {
85
+ return Ok ( value. clone ( ) ) ;
86
+ }
87
+ let value = value. extract :: < Option < & PyInt > > ( ) ?;
77
88
let value = match value {
78
89
None => RealAlgebraicNumber :: zero ( ) ,
79
90
Some ( value) => RealAlgebraicNumber :: from ( value. extract :: < PyBigInt > ( ) ?. 0 ) ,
80
- } ;
81
- obj. init ( RealAlgebraicNumberPy { value } ) ;
91
+ }
92
+ . into ( ) ;
93
+ Ok ( RealAlgebraicNumberPy { value } )
94
+ }
95
+ }
96
+
97
+ #[ pymethods( PyObjectProtocol , PyNumberProtocol ) ]
98
+ impl RealAlgebraicNumberPy {
99
+ #[ new]
100
+ fn pynew ( obj : & PyRawObject , value : RealAlgebraicNumberPy ) -> PyResult < ( ) > {
101
+ obj. init ( value) ;
82
102
Ok ( ( ) )
83
103
}
84
104
// FIXME: implement rest of methods
@@ -89,6 +109,97 @@ impl PyObjectProtocol for RealAlgebraicNumberPy {
89
109
fn __repr__ ( & self ) -> PyResult < String > {
90
110
Ok ( format ! ( "{:?}" , self . value) )
91
111
}
112
+ fn __richcmp__ ( & self , other : & PyAny , op : CompareOp ) -> PyResult < bool > {
113
+ let py = other. py ( ) ;
114
+ let other = other. extract :: < RealAlgebraicNumberPy > ( ) ?;
115
+ Ok ( py. allow_threads ( || match op {
116
+ CompareOp :: Lt => self . value < other. value ,
117
+ CompareOp :: Le => self . value <= other. value ,
118
+ CompareOp :: Eq => self . value == other. value ,
119
+ CompareOp :: Ne => self . value != other. value ,
120
+ CompareOp :: Gt => self . value > other. value ,
121
+ CompareOp :: Ge => self . value >= other. value ,
122
+ } ) )
123
+ }
124
+ }
125
+
126
+ #[ pyproto]
127
+ impl PyNumberProtocol for RealAlgebraicNumberPy {
128
+ fn __add__ ( lhs : & PyAny , rhs : RealAlgebraicNumberPy ) -> PyResult < RealAlgebraicNumberPy > {
129
+ let py = lhs. py ( ) ;
130
+ let mut lhs = lhs. extract :: < RealAlgebraicNumberPy > ( ) ?;
131
+ Ok ( py. allow_threads ( || {
132
+ * Arc :: make_mut ( & mut lhs. value ) += & * rhs. value ;
133
+ lhs
134
+ } ) )
135
+ }
136
+ fn __sub__ ( lhs : & PyAny , rhs : RealAlgebraicNumberPy ) -> PyResult < RealAlgebraicNumberPy > {
137
+ let py = lhs. py ( ) ;
138
+ let mut lhs = lhs. extract :: < RealAlgebraicNumberPy > ( ) ?;
139
+ Ok ( py. allow_threads ( || {
140
+ * Arc :: make_mut ( & mut lhs. value ) -= & * rhs. value ;
141
+ lhs
142
+ } ) )
143
+ }
144
+ fn __mul__ ( lhs : & PyAny , rhs : RealAlgebraicNumberPy ) -> PyResult < RealAlgebraicNumberPy > {
145
+ let py = lhs. py ( ) ;
146
+ let mut lhs = lhs. extract :: < RealAlgebraicNumberPy > ( ) ?;
147
+ Ok ( py. allow_threads ( || {
148
+ * Arc :: make_mut ( & mut lhs. value ) *= & * rhs. value ;
149
+ lhs
150
+ } ) )
151
+ }
152
+ fn __truediv__ ( lhs : & PyAny , rhs : RealAlgebraicNumberPy ) -> PyResult < RealAlgebraicNumberPy > {
153
+ let py = lhs. py ( ) ;
154
+ let mut lhs = lhs. extract :: < RealAlgebraicNumberPy > ( ) ?;
155
+ py. allow_threads ( || -> Result < RealAlgebraicNumberPy , ( ) > {
156
+ Arc :: make_mut ( & mut lhs. value ) . checked_exact_div_assign ( & * rhs. value ) ?;
157
+ Ok ( lhs)
158
+ } )
159
+ . map_err ( |( ) | ZeroDivisionError :: py_err ( "can't divide RealAlgebraicNumber by zero" ) )
160
+ }
161
+ fn __pow__ (
162
+ lhs : RealAlgebraicNumberPy ,
163
+ rhs : RealAlgebraicNumberPy ,
164
+ modulus : & PyAny ,
165
+ ) -> PyResult < RealAlgebraicNumberPy > {
166
+ let py = modulus. py ( ) ;
167
+ if !modulus. is_none ( ) {
168
+ return Err ( TypeError :: py_err (
169
+ "3 argument pow() not allowed for RealAlgebraicNumber" ,
170
+ ) ) ;
171
+ }
172
+ py. allow_threads ( || -> Result < RealAlgebraicNumberPy , & ' static str > {
173
+ if let Some ( rhs) = rhs. value . to_rational ( ) {
174
+ Ok ( RealAlgebraicNumberPy {
175
+ value : lhs
176
+ . value
177
+ . checked_pow ( rhs)
178
+ . ok_or ( "pow() failed for RealAlgebraicNumber" ) ?
179
+ . into ( ) ,
180
+ } )
181
+ } else {
182
+ Err ( "exponent must be rational for RealAlgebraicNumber" )
183
+ }
184
+ } )
185
+ . map_err ( ValueError :: py_err)
186
+ }
187
+
188
+ // Unary arithmetic
189
+ fn __neg__ ( & self ) -> PyResult < RealAlgebraicNumberPy > {
190
+ Ok ( Python :: acquire_gil ( )
191
+ . python ( )
192
+ . allow_threads ( || RealAlgebraicNumberPy {
193
+ value : Arc :: from ( -& * self . value ) ,
194
+ } ) )
195
+ }
196
+ fn __abs__ ( & self ) -> PyResult < RealAlgebraicNumberPy > {
197
+ Ok ( Python :: acquire_gil ( )
198
+ . python ( )
199
+ . allow_threads ( || RealAlgebraicNumberPy {
200
+ value : self . value . abs ( ) . into ( ) ,
201
+ } ) )
202
+ }
92
203
}
93
204
94
205
#[ pymodule]
@@ -97,3 +208,5 @@ fn algebraics(_py: Python, m: &PyModule) -> PyResult<()> {
97
208
m. add_class :: < RealAlgebraicNumberPy > ( ) ?;
98
209
Ok ( ( ) )
99
210
}
211
+
212
+ // FIXME: add tests
0 commit comments