Skip to content

Commit e4646ba

Browse files
committed
Avoid unexpected conversion with PyArrayLike1
1 parent 06d45ed commit e4646ba

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

src/array_like.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,25 @@ impl Coerce for AllowTypeChange {
4242
const VAL: bool = true;
4343
}
4444

45+
trait IsNumpyNDArray {
46+
fn is_numpy_ndarray(&self) -> PyResult<bool>;
47+
}
48+
49+
impl<'py> IsNumpyNDArray for Borrowed<'_, 'py, PyAny> {
50+
fn is_numpy_ndarray(&self) -> PyResult<bool> {
51+
let py = self.py();
52+
53+
static NDARRAY: PyOnceLock<Py<PyAny>> = PyOnceLock::new();
54+
let ndarray = NDARRAY
55+
.get_or_try_init(py, || {
56+
get_array_module(py)?.getattr("ndarray").map(Into::into)
57+
})?
58+
.bind(py);
59+
60+
self.is_instance(ndarray)
61+
}
62+
}
63+
4564
/// Receiver for arrays or array-like types.
4665
///
4766
/// When building API using NumPy in Python, it is common for functions to additionally accept any array-like type such as `list[float]` as arguments.
@@ -151,7 +170,9 @@ where
151170

152171
let py = ob.py();
153172

154-
if matches!(D::NDIM, None | Some(1)) {
173+
// If the input is already an ndarray and `TypeMustMatch` is used then any conversion
174+
// should be performed.
175+
if (C::VAL || !ob.is_numpy_ndarray()?) && matches!(D::NDIM, None | Some(1)) {
155176
if let Ok(vec) = ob.extract::<Vec<T>>() {
156177
let array = Array1::from(vec)
157178
.into_dimensionality()

0 commit comments

Comments
 (0)