File tree Expand file tree Collapse file tree 1 file changed +22
-1
lines changed
Expand file tree Collapse file tree 1 file changed +22
-1
lines changed Original file line number Diff line number Diff line change @@ -42,6 +42,25 @@ impl Coerce for AllowTypeChange {
4242 const VAL : bool = true ;
4343}
4444
45+ trait IsNumpyNDArray {
46+ fn is_numpy_ndarray ( & self ) -> PyResult < bool > ;
47+ }
48+
49+ impl < ' py > IsNumpyNDArray for Borrowed < ' _ , ' py , PyAny > {
50+ fn is_numpy_ndarray ( & self ) -> PyResult < bool > {
51+ let py = self . py ( ) ;
52+
53+ static NDARRAY : PyOnceLock < Py < PyAny > > = PyOnceLock :: new ( ) ;
54+ let ndarray = NDARRAY
55+ . get_or_try_init ( py, || {
56+ get_array_module ( py) ?. getattr ( "ndarray" ) . map ( Into :: into)
57+ } ) ?
58+ . bind ( py) ;
59+
60+ self . is_instance ( ndarray)
61+ }
62+ }
63+
4564/// Receiver for arrays or array-like types.
4665///
4766/// When building API using NumPy in Python, it is common for functions to additionally accept any array-like type such as `list[float]` as arguments.
@@ -151,7 +170,9 @@ where
151170
152171 let py = ob. py ( ) ;
153172
154- if matches ! ( D :: NDIM , None | Some ( 1 ) ) {
173+ // If the input is already an ndarray and `TypeMustMatch` is used then any conversion
174+ // should be performed.
175+ if ( C :: VAL || !ob. is_numpy_ndarray ( ) ?) && matches ! ( D :: NDIM , None | Some ( 1 ) ) {
155176 if let Ok ( vec) = ob. extract :: < Vec < T > > ( ) {
156177 let array = Array1 :: from ( vec)
157178 . into_dimensionality ( )
You can’t perform that action at this time.
0 commit comments