diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fd49a2ba..97707dc45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,7 @@ # Changelog +- v0.28.0 + - Fix mismatched behavior between `PyArrayLike1` and `PyArrayLike2` when used with floats ([#520](https://github.com/PyO3/rust-numpy/pull/520)) + - v0.27.1 - Bump ndarray dependency to v0.17. ([#516](https://github.com/PyO3/rust-numpy/pull/516)) diff --git a/src/array_like.rs b/src/array_like.rs index 0419c1438..dfea56165 100644 --- a/src/array_like.rs +++ b/src/array_like.rs @@ -10,10 +10,10 @@ use pyo3::{ }; use crate::array::PyArrayMethods; -use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray}; +use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray, PyUntypedArray}; pub trait Coerce: Sealed { - const VAL: bool; + const ALLOW_TYPE_CHANGE: bool; } mod sealed { @@ -29,7 +29,7 @@ pub struct TypeMustMatch; impl Sealed for TypeMustMatch {} impl Coerce for TypeMustMatch { - const VAL: bool = false; + const ALLOW_TYPE_CHANGE: bool = false; } /// Marker type to indicate that the element type received via [`PyArrayLike`] can be cast to the specified type by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html). @@ -39,7 +39,7 @@ pub struct AllowTypeChange; impl Sealed for AllowTypeChange {} impl Coerce for AllowTypeChange { - const VAL: bool = true; + const ALLOW_TYPE_CHANGE: bool = true; } /// Receiver for arrays or array-like types. @@ -151,7 +151,11 @@ where let py = ob.py(); - if matches!(D::NDIM, None | Some(1)) { + // If the input is already an ndarray and `TypeMustMatch` is used then no type conversion + // should be performed. + if (C::ALLOW_TYPE_CHANGE || ob.cast::().is_err()) + && matches!(D::NDIM, None | Some(1)) + { if let Ok(vec) = ob.extract::>() { let array = Array1::from(vec) .into_dimensionality() @@ -170,7 +174,7 @@ where })? .bind(py); - let kwargs = if C::VAL { + let kwargs = if C::ALLOW_TYPE_CHANGE { let kwargs = PyDict::new(py); kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?; Some(kwargs) diff --git a/tests/array_like.rs b/tests/array_like.rs index 9a2afcfdf..d08e98abf 100644 --- a/tests/array_like.rs +++ b/tests/array_like.rs @@ -132,6 +132,50 @@ fn unsafe_cast_shall_fail() { }); } +#[test] +fn extract_1d_array_of_different_float_types_fail() { + Python::attach(|py| { + let locals = get_np_locals(py); + let py_list = py + .eval( + c_str!("np.array([1, 2, 3, 4], dtype='float64')"), + Some(&locals), + None, + ) + .unwrap(); + let extracted_array_f32 = py_list.extract::>(); + let extracted_array_f64 = py_list.extract::>().unwrap(); + + assert!(extracted_array_f32.is_err()); + assert_eq!( + array![1_f64, 2_f64, 3_f64, 4_f64], + extracted_array_f64.as_array() + ); + }); +} + +#[test] +fn extract_2d_array_of_different_float_types_fail() { + Python::attach(|py| { + let locals = get_np_locals(py); + let py_list = py + .eval( + c_str!("np.array([[1, 2], [3, 4]], dtype='float64')"), + Some(&locals), + None, + ) + .unwrap(); + let extracted_array_f32 = py_list.extract::>(); + let extracted_array_f64 = py_list.extract::>().unwrap(); + + assert!(extracted_array_f32.is_err()); + assert_eq!( + array![[1_f64, 2_f64], [3_f64, 4_f64]], + extracted_array_f64.as_array() + ); + }); +} + #[test] fn unsafe_cast_with_coerce_works() { Python::attach(|py| {