Skip to content

Commit 93a3aaf

Browse files
authored
Merge pull request #330 from PyO3/recovering-polymorphism
RFC: Extend simple exmaple to include a function with limited polymorphism based on enums and FromPyObject.
2 parents 477c9d4 + 5f79d24 commit 93a3aaf

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

examples/simple/src/lib.rs

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD, Zip};
1+
use std::ops::Add;
2+
3+
use numpy::ndarray::{Array1, ArrayD, ArrayView1, ArrayViewD, ArrayViewMutD, Zip};
24
use numpy::{
35
datetime::{units, Timedelta},
46
Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArray1, PyReadonlyArrayDyn,
@@ -7,7 +9,7 @@ use numpy::{
79
use pyo3::{
810
pymodule,
911
types::{PyDict, PyModule},
10-
PyResult, Python,
12+
FromPyObject, PyAny, PyResult, Python,
1113
};
1214

1315
#[pymodule]
@@ -27,6 +29,11 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
2729
x.map(|c| c.conj())
2830
}
2931

32+
// example using generics
33+
fn generic_add<T: Copy + Add<Output = T>>(x: ArrayView1<T>, y: ArrayView1<T>) -> Array1<T> {
34+
&x + &y
35+
}
36+
3037
// wrapper of `axpy`
3138
#[pyfn(m)]
3239
#[pyo3(name = "axpy")]
@@ -84,5 +91,47 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
8491
.apply(|x, y| *x = (i64::from(*x) + 60 * i64::from(*y)).into());
8592
}
8693

94+
// This crate follows a strongly-typed approach to wrapping NumPy arrays
95+
// while Python API are often expected to work with multiple element types.
96+
//
97+
// That kind of limited polymorphis can be recovered by accepting an enumerated type
98+
// covering the supported element types and dispatching into a generic implementation.
99+
#[derive(FromPyObject)]
100+
enum SupportedArray<'py> {
101+
F64(&'py PyArray1<f64>),
102+
I64(&'py PyArray1<i64>),
103+
}
104+
105+
#[pyfn(m)]
106+
fn polymorphic_add<'py>(
107+
x: SupportedArray<'py>,
108+
y: SupportedArray<'py>,
109+
) -> PyResult<&'py PyAny> {
110+
match (x, y) {
111+
(SupportedArray::F64(x), SupportedArray::F64(y)) => Ok(generic_add(
112+
x.readonly().as_array(),
113+
y.readonly().as_array(),
114+
)
115+
.into_pyarray(x.py())
116+
.into()),
117+
(SupportedArray::I64(x), SupportedArray::I64(y)) => Ok(generic_add(
118+
x.readonly().as_array(),
119+
y.readonly().as_array(),
120+
)
121+
.into_pyarray(x.py())
122+
.into()),
123+
(SupportedArray::F64(x), SupportedArray::I64(y))
124+
| (SupportedArray::I64(y), SupportedArray::F64(x)) => {
125+
let y = y.cast::<f64>(false)?;
126+
127+
Ok(
128+
generic_add(x.readonly().as_array(), y.readonly().as_array())
129+
.into_pyarray(x.py())
130+
.into(),
131+
)
132+
}
133+
}
134+
}
135+
87136
Ok(())
88137
}

examples/simple/tests/test_ext.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from rust_ext import axpy, conj, mult, extract, add_minutes_to_seconds
2+
from rust_ext import axpy, conj, mult, extract, add_minutes_to_seconds, polymorphic_add
33

44

55
def test_axpy():
@@ -33,3 +33,20 @@ def test_add_minutes_to_seconds():
3333
add_minutes_to_seconds(x, y)
3434

3535
assert np.all(x == np.array([70, 140, 210], dtype="timedelta64[s]"))
36+
37+
38+
def test_polymorphic_add():
39+
x = np.array([1.0, 2.0, 3.0], dtype=np.double)
40+
y = np.array([3.0, 3.0, 3.0], dtype=np.double)
41+
z = polymorphic_add(x, y)
42+
np.testing.assert_array_almost_equal(z, np.array([4.0, 5.0, 6.0], dtype=np.double))
43+
44+
x = np.array([1, 2, 3], dtype=np.int64)
45+
y = np.array([3, 3, 3], dtype=np.int64)
46+
z = polymorphic_add(x, y)
47+
assert np.all(z == np.array([4, 5, 6], dtype=np.int64))
48+
49+
x = np.array([1.0, 2.0, 3.0], dtype=np.double)
50+
y = np.array([3, 3, 3], dtype=np.int64)
51+
z = polymorphic_add(x, y)
52+
np.testing.assert_array_almost_equal(z, np.array([4.0, 5.0, 6.0], dtype=np.double))

0 commit comments

Comments
 (0)