Skip to content

Commit 5d55c70

Browse files
committed
Use is_equiv_to() in FromPyObject::extract()
1 parent 612e867 commit 5d55c70

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/array.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,13 @@ impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
136136
}
137137
&*(ob as *const PyAny as *const PyArray<T, D>)
138138
};
139-
let dtype = array.dtype();
139+
let src_dtype = array.dtype();
140+
let dst_dtype = T::get_dtype(ob.py());
140141
let dim = array.shape().len();
141-
if T::is_same_type(dtype) && D::NDIM.map(|n| n == dim).unwrap_or(true) {
142+
if src_dtype.is_equiv_to(dst_dtype) && D::NDIM.map(|n| n == dim).unwrap_or(true) {
142143
Ok(array)
143144
} else {
144-
Err(ShapeError::new(dtype, dim, T::DATA_TYPE, D::NDIM).into())
145+
Err(ShapeError::new(src_dtype, dim, T::DATA_TYPE, D::NDIM).into())
145146
}
146147
}
147148
}
@@ -457,10 +458,10 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
457458
T::get_dtype(py).into_dtype_ptr(),
458459
dims.ndim_cint(),
459460
dims.as_dims_ptr(),
460-
strides as *mut npy_intp, // strides
461-
data_ptr as *mut c_void, // data
462-
npyffi::NPY_ARRAY_WRITEABLE, // flag
463-
ptr::null_mut(), // obj
461+
strides as *mut npy_intp, // strides
462+
data_ptr as *mut c_void, // data
463+
npyffi::NPY_ARRAY_WRITEABLE, // flag
464+
ptr::null_mut(), // obj
464465
);
465466

466467
PY_ARRAY_API.PyArray_SetBaseObject(

0 commit comments

Comments
 (0)