Skip to content

Commit 7357a1e

Browse files
committed
Implement PyArray::cast
1 parent cd3f156 commit 7357a1e

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

src/array.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,4 +515,35 @@ impl<T: TypeNum> PyArray<T> {
515515
Ok(())
516516
}
517517
}
518+
519+
/// Cast the `PyArray<T>` to `PyArray<U>`, by allocating a new array.
520+
/// # Example
521+
/// ```
522+
/// # extern crate pyo3; extern crate numpy; fn main() {
523+
/// use numpy::{PyArray, PyArrayModule, IntoPyArray};
524+
/// let gil = pyo3::Python::acquire_gil();
525+
/// let np = PyArrayModule::import(gil.python()).unwrap();
526+
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), &np, 2.0, 5.0, 1.0);
527+
/// let pyarray_i = pyarray_f.cast::<i32>(gil.python(), &np, false).unwrap();
528+
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
529+
/// # }
530+
pub fn cast<U: TypeNum>(
531+
&self,
532+
py: Python,
533+
np: &PyArrayModule,
534+
is_fortran: bool,
535+
) -> Result<PyArray<U>, ArrayCastError> {
536+
let ptr = unsafe {
537+
let descr = np.PyArray_DescrFromType(U::typenum_default());
538+
np.PyArray_CastToType(self.as_array_ptr(), descr, if is_fortran { -1 } else { 0 })
539+
};
540+
if ptr.is_null() {
541+
Err(ArrayCastError::Numpy {
542+
from: T::npy_data_type(),
543+
to: U::npy_data_type(),
544+
})
545+
} else {
546+
unsafe { Ok(PyArray::<U>::from_owned_ptr(py, ptr)) }
547+
}
548+
}
518549
}

0 commit comments

Comments
 (0)