Skip to content

Commit d7d2904

Browse files
added some arithmetic ops
1 parent c442d85 commit d7d2904

File tree

1 file changed

+120
-7
lines changed

1 file changed

+120
-7
lines changed

src/python.rs

Lines changed: 120 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,26 @@
44
#![cfg(feature = "python-extension")]
55

66
use crate::algebraic_numbers::RealAlgebraicNumber;
7+
use crate::traits::ExactDivAssign;
78
use num_bigint::BigInt;
89
use num_bigint::Sign;
10+
use num_traits::Signed;
911
use num_traits::ToPrimitive;
1012
use num_traits::Zero;
13+
use pyo3::basic::CompareOp;
14+
use pyo3::exceptions::TypeError;
15+
use pyo3::exceptions::ValueError;
16+
use pyo3::exceptions::ZeroDivisionError;
1117
use pyo3::prelude::*;
1218
use pyo3::types::IntoPyDict;
1319
use pyo3::types::PyAny;
1420
use pyo3::types::PyBytes;
1521
use pyo3::types::PyInt;
1622
use pyo3::types::PyType;
1723
use pyo3::PyNativeType;
24+
use pyo3::PyNumberProtocol;
1825
use pyo3::PyObjectProtocol;
26+
use std::sync::Arc;
1927

2028
// TODO: Switch to using BigInt's python conversions once they are implemented
2129
// see https://github.com/PyO3/pyo3/issues/543
@@ -66,19 +74,31 @@ impl FromPyObject<'_> for PyBigInt {
6674
}
6775

6876
#[pyclass(name=RealAlgebraicNumber, module="algebraics")]
77+
#[derive(Clone)]
6978
struct RealAlgebraicNumberPy {
70-
value: RealAlgebraicNumber,
79+
value: Arc<RealAlgebraicNumber>,
7180
}
7281

73-
#[pymethods(PyObjectProtocol)]
74-
impl RealAlgebraicNumberPy {
75-
#[new]
76-
fn pynew(obj: &PyRawObject, value: Option<&PyInt>) -> PyResult<()> {
82+
impl FromPyObject<'_> for RealAlgebraicNumberPy {
83+
fn extract(value: &PyAny) -> PyResult<Self> {
84+
if let Ok(value) = value.downcast_ref::<RealAlgebraicNumberPy>() {
85+
return Ok(value.clone());
86+
}
87+
let value = value.extract::<Option<&PyInt>>()?;
7788
let value = match value {
7889
None => RealAlgebraicNumber::zero(),
7990
Some(value) => RealAlgebraicNumber::from(value.extract::<PyBigInt>()?.0),
80-
};
81-
obj.init(RealAlgebraicNumberPy { value });
91+
}
92+
.into();
93+
Ok(RealAlgebraicNumberPy { value })
94+
}
95+
}
96+
97+
#[pymethods(PyObjectProtocol, PyNumberProtocol)]
98+
impl RealAlgebraicNumberPy {
99+
#[new]
100+
fn pynew(obj: &PyRawObject, value: RealAlgebraicNumberPy) -> PyResult<()> {
101+
obj.init(value);
82102
Ok(())
83103
}
84104
// FIXME: implement rest of methods
@@ -89,6 +109,97 @@ impl PyObjectProtocol for RealAlgebraicNumberPy {
89109
fn __repr__(&self) -> PyResult<String> {
90110
Ok(format!("{:?}", self.value))
91111
}
112+
fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult<bool> {
113+
let py = other.py();
114+
let other = other.extract::<RealAlgebraicNumberPy>()?;
115+
Ok(py.allow_threads(|| match op {
116+
CompareOp::Lt => self.value < other.value,
117+
CompareOp::Le => self.value <= other.value,
118+
CompareOp::Eq => self.value == other.value,
119+
CompareOp::Ne => self.value != other.value,
120+
CompareOp::Gt => self.value > other.value,
121+
CompareOp::Ge => self.value >= other.value,
122+
}))
123+
}
124+
}
125+
126+
#[pyproto]
127+
impl PyNumberProtocol for RealAlgebraicNumberPy {
128+
fn __add__(lhs: &PyAny, rhs: RealAlgebraicNumberPy) -> PyResult<RealAlgebraicNumberPy> {
129+
let py = lhs.py();
130+
let mut lhs = lhs.extract::<RealAlgebraicNumberPy>()?;
131+
Ok(py.allow_threads(|| {
132+
*Arc::make_mut(&mut lhs.value) += &*rhs.value;
133+
lhs
134+
}))
135+
}
136+
fn __sub__(lhs: &PyAny, rhs: RealAlgebraicNumberPy) -> PyResult<RealAlgebraicNumberPy> {
137+
let py = lhs.py();
138+
let mut lhs = lhs.extract::<RealAlgebraicNumberPy>()?;
139+
Ok(py.allow_threads(|| {
140+
*Arc::make_mut(&mut lhs.value) -= &*rhs.value;
141+
lhs
142+
}))
143+
}
144+
fn __mul__(lhs: &PyAny, rhs: RealAlgebraicNumberPy) -> PyResult<RealAlgebraicNumberPy> {
145+
let py = lhs.py();
146+
let mut lhs = lhs.extract::<RealAlgebraicNumberPy>()?;
147+
Ok(py.allow_threads(|| {
148+
*Arc::make_mut(&mut lhs.value) *= &*rhs.value;
149+
lhs
150+
}))
151+
}
152+
fn __truediv__(lhs: &PyAny, rhs: RealAlgebraicNumberPy) -> PyResult<RealAlgebraicNumberPy> {
153+
let py = lhs.py();
154+
let mut lhs = lhs.extract::<RealAlgebraicNumberPy>()?;
155+
py.allow_threads(|| -> Result<RealAlgebraicNumberPy, ()> {
156+
Arc::make_mut(&mut lhs.value).checked_exact_div_assign(&*rhs.value)?;
157+
Ok(lhs)
158+
})
159+
.map_err(|()| ZeroDivisionError::py_err("can't divide RealAlgebraicNumber by zero"))
160+
}
161+
fn __pow__(
162+
lhs: RealAlgebraicNumberPy,
163+
rhs: RealAlgebraicNumberPy,
164+
modulus: &PyAny,
165+
) -> PyResult<RealAlgebraicNumberPy> {
166+
let py = modulus.py();
167+
if !modulus.is_none() {
168+
return Err(TypeError::py_err(
169+
"3 argument pow() not allowed for RealAlgebraicNumber",
170+
));
171+
}
172+
py.allow_threads(|| -> Result<RealAlgebraicNumberPy, &'static str> {
173+
if let Some(rhs) = rhs.value.to_rational() {
174+
Ok(RealAlgebraicNumberPy {
175+
value: lhs
176+
.value
177+
.checked_pow(rhs)
178+
.ok_or("pow() failed for RealAlgebraicNumber")?
179+
.into(),
180+
})
181+
} else {
182+
Err("exponent must be rational for RealAlgebraicNumber")
183+
}
184+
})
185+
.map_err(ValueError::py_err)
186+
}
187+
188+
// Unary arithmetic
189+
fn __neg__(&self) -> PyResult<RealAlgebraicNumberPy> {
190+
Ok(Python::acquire_gil()
191+
.python()
192+
.allow_threads(|| RealAlgebraicNumberPy {
193+
value: Arc::from(-&*self.value),
194+
}))
195+
}
196+
fn __abs__(&self) -> PyResult<RealAlgebraicNumberPy> {
197+
Ok(Python::acquire_gil()
198+
.python()
199+
.allow_threads(|| RealAlgebraicNumberPy {
200+
value: self.value.abs().into(),
201+
}))
202+
}
92203
}
93204

94205
#[pymodule]
@@ -97,3 +208,5 @@ fn algebraics(_py: Python, m: &PyModule) -> PyResult<()> {
97208
m.add_class::<RealAlgebraicNumberPy>()?;
98209
Ok(())
99210
}
211+
212+
// FIXME: add tests

0 commit comments

Comments
 (0)