1
- use numpy:: ndarray:: { ArrayD , ArrayViewD , ArrayViewMutD , Zip } ;
1
+ use std:: ops:: Add ;
2
+
3
+ use numpy:: ndarray:: { Array1 , ArrayD , ArrayView1 , ArrayViewD , ArrayViewMutD , Zip } ;
2
4
use numpy:: {
3
5
datetime:: { units, Timedelta } ,
4
6
Complex64 , IntoPyArray , PyArray1 , PyArrayDyn , PyReadonlyArray1 , PyReadonlyArrayDyn ,
@@ -7,7 +9,7 @@ use numpy::{
7
9
use pyo3:: {
8
10
pymodule,
9
11
types:: { PyDict , PyModule } ,
10
- PyResult , Python ,
12
+ FromPyObject , PyAny , PyResult , Python ,
11
13
} ;
12
14
13
15
#[ pymodule]
@@ -27,6 +29,11 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
27
29
x. map ( |c| c. conj ( ) )
28
30
}
29
31
32
+ // example using generics
33
+ fn generic_add < T : Copy + Add < Output = T > > ( x : ArrayView1 < T > , y : ArrayView1 < T > ) -> Array1 < T > {
34
+ & x + & y
35
+ }
36
+
30
37
// wrapper of `axpy`
31
38
#[ pyfn( m) ]
32
39
#[ pyo3( name = "axpy" ) ]
@@ -84,5 +91,47 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
84
91
. apply ( |x, y| * x = ( i64:: from ( * x) + 60 * i64:: from ( * y) ) . into ( ) ) ;
85
92
}
86
93
94
+ // This crate follows a strongly-typed approach to wrapping NumPy arrays
95
+ // while Python API are often expected to work with multiple element types.
96
+ //
97
+ // That kind of limited polymorphis can be recovered by accepting an enumerated type
98
+ // covering the supported element types and dispatching into a generic implementation.
99
+ #[ derive( FromPyObject ) ]
100
+ enum SupportedArray < ' py > {
101
+ F64 ( & ' py PyArray1 < f64 > ) ,
102
+ I64 ( & ' py PyArray1 < i64 > ) ,
103
+ }
104
+
105
+ #[ pyfn( m) ]
106
+ fn polymorphic_add < ' py > (
107
+ x : SupportedArray < ' py > ,
108
+ y : SupportedArray < ' py > ,
109
+ ) -> PyResult < & ' py PyAny > {
110
+ match ( x, y) {
111
+ ( SupportedArray :: F64 ( x) , SupportedArray :: F64 ( y) ) => Ok ( generic_add (
112
+ x. readonly ( ) . as_array ( ) ,
113
+ y. readonly ( ) . as_array ( ) ,
114
+ )
115
+ . into_pyarray ( x. py ( ) )
116
+ . into ( ) ) ,
117
+ ( SupportedArray :: I64 ( x) , SupportedArray :: I64 ( y) ) => Ok ( generic_add (
118
+ x. readonly ( ) . as_array ( ) ,
119
+ y. readonly ( ) . as_array ( ) ,
120
+ )
121
+ . into_pyarray ( x. py ( ) )
122
+ . into ( ) ) ,
123
+ ( SupportedArray :: F64 ( x) , SupportedArray :: I64 ( y) )
124
+ | ( SupportedArray :: I64 ( y) , SupportedArray :: F64 ( x) ) => {
125
+ let y = y. cast :: < f64 > ( false ) ?;
126
+
127
+ Ok (
128
+ generic_add ( x. readonly ( ) . as_array ( ) , y. readonly ( ) . as_array ( ) )
129
+ . into_pyarray ( x. py ( ) )
130
+ . into ( ) ,
131
+ )
132
+ }
133
+ }
134
+ }
135
+
87
136
Ok ( ( ) )
88
137
}
0 commit comments