@@ -136,12 +136,13 @@ impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
136
136
}
137
137
& * ( ob as * const PyAny as * const PyArray < T , D > )
138
138
} ;
139
- let dtype = array. dtype ( ) ;
139
+ let src_dtype = array. dtype ( ) ;
140
+ let dst_dtype = T :: get_dtype ( ob. py ( ) ) ;
140
141
let dim = array. shape ( ) . len ( ) ;
141
- if T :: is_same_type ( dtype ) && D :: NDIM . map ( |n| n == dim) . unwrap_or ( true ) {
142
+ if src_dtype . is_equiv_to ( dst_dtype ) && D :: NDIM . map ( |n| n == dim) . unwrap_or ( true ) {
142
143
Ok ( array)
143
144
} else {
144
- Err ( ShapeError :: new ( dtype , dim, T :: DATA_TYPE , D :: NDIM ) . into ( ) )
145
+ Err ( ShapeError :: new ( src_dtype , dim, T :: DATA_TYPE , D :: NDIM ) . into ( ) )
145
146
}
146
147
}
147
148
}
@@ -457,10 +458,10 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
457
458
T :: get_dtype ( py) . into_dtype_ptr ( ) ,
458
459
dims. ndim_cint ( ) ,
459
460
dims. as_dims_ptr ( ) ,
460
- strides as * mut npy_intp , // strides
461
- data_ptr as * mut c_void , // data
462
- npyffi:: NPY_ARRAY_WRITEABLE , // flag
463
- ptr:: null_mut ( ) , // obj
461
+ strides as * mut npy_intp , // strides
462
+ data_ptr as * mut c_void , // data
463
+ npyffi:: NPY_ARRAY_WRITEABLE , // flag
464
+ ptr:: null_mut ( ) , // obj
464
465
) ;
465
466
466
467
PY_ARRAY_API . PyArray_SetBaseObject (
0 commit comments