Skip to content

Commit e544705

Browse files
authored
Merge pull request #60 from kngwyu/cast
Implement copy and cast methods
2 parents 0b941e4 + 4df8836 commit e544705

File tree

3 files changed

+117
-8
lines changed

3 files changed

+117
-8
lines changed

src/array.rs

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,4 +455,95 @@ impl<T: TypeNum> PyArray<T> {
455455
Self::from_owned_ptr(py, ptr)
456456
}
457457
}
458+
459+
/// Copies self into `other`, performing a data-type conversion if necessary.
460+
/// # Example
461+
/// ```
462+
/// # extern crate pyo3; extern crate numpy; fn main() {
463+
/// use numpy::{PyArray, PyArrayModule, IntoPyArray};
464+
/// let gil = pyo3::Python::acquire_gil();
465+
/// let np = PyArrayModule::import(gil.python()).unwrap();
466+
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), &np, 2.0, 5.0, 1.0);
467+
/// let mut pyarray_i = PyArray::<i64>::new(gil.python(), &np, &[3]);
468+
/// assert!(pyarray_f.copy_to(&np, &mut pyarray_i).is_ok());
469+
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
470+
/// # }
471+
pub fn copy_to<U: TypeNum>(
472+
&self,
473+
np: &PyArrayModule,
474+
other: &mut PyArray<U>,
475+
) -> Result<(), ArrayCastError> {
476+
let self_ptr = self.as_array_ptr();
477+
let other_ptr = other.as_array_ptr();
478+
let result = unsafe { np.PyArray_CopyInto(other_ptr, self_ptr) };
479+
if result == -1 {
480+
Err(ArrayCastError::Numpy {
481+
from: T::npy_data_type(),
482+
to: U::npy_data_type(),
483+
})
484+
} else {
485+
Ok(())
486+
}
487+
}
488+
489+
/// Move the data of self into `other`, performing a data-type conversion if necessary.
490+
/// # Example
491+
/// ```
492+
/// # extern crate pyo3; extern crate numpy; fn main() {
493+
/// use numpy::{PyArray, PyArrayModule, IntoPyArray};
494+
/// let gil = pyo3::Python::acquire_gil();
495+
/// let np = PyArrayModule::import(gil.python()).unwrap();
496+
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), &np, 2.0, 5.0, 1.0);
497+
/// let mut pyarray_i = PyArray::<i64>::new(gil.python(), &np, &[3]);
498+
/// assert!(pyarray_f.move_to(&np, &mut pyarray_i).is_ok());
499+
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
500+
/// # }
501+
pub fn move_to<U: TypeNum>(
502+
self,
503+
np: &PyArrayModule,
504+
other: &mut PyArray<U>,
505+
) -> Result<(), ArrayCastError> {
506+
let self_ptr = self.as_array_ptr();
507+
let other_ptr = other.as_array_ptr();
508+
let result = unsafe { np.PyArray_MoveInto(other_ptr, self_ptr) };
509+
if result == -1 {
510+
Err(ArrayCastError::Numpy {
511+
from: T::npy_data_type(),
512+
to: U::npy_data_type(),
513+
})
514+
} else {
515+
Ok(())
516+
}
517+
}
518+
519+
/// Cast the `PyArray<T>` to `PyArray<U>`, by allocating a new array.
520+
/// # Example
521+
/// ```
522+
/// # extern crate pyo3; extern crate numpy; fn main() {
523+
/// use numpy::{PyArray, PyArrayModule, IntoPyArray};
524+
/// let gil = pyo3::Python::acquire_gil();
525+
/// let np = PyArrayModule::import(gil.python()).unwrap();
526+
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), &np, 2.0, 5.0, 1.0);
527+
/// let pyarray_i = pyarray_f.cast::<i32>(gil.python(), &np, false).unwrap();
528+
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
529+
/// # }
530+
pub fn cast<U: TypeNum>(
531+
&self,
532+
py: Python,
533+
np: &PyArrayModule,
534+
is_fortran: bool,
535+
) -> Result<PyArray<U>, ArrayCastError> {
536+
let ptr = unsafe {
537+
let descr = np.PyArray_DescrFromType(U::typenum_default());
538+
np.PyArray_CastToType(self.as_array_ptr(), descr, if is_fortran { -1 } else { 0 })
539+
};
540+
if ptr.is_null() {
541+
Err(ArrayCastError::Numpy {
542+
from: T::npy_data_type(),
543+
to: U::npy_data_type(),
544+
})
545+
} else {
546+
unsafe { Ok(PyArray::<U>::from_owned_ptr(py, ptr)) }
547+
}
548+
}
458549
}

src/error.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ pub enum ArrayCastError {
2828
ToRust { from: NpyDataType, to: NpyDataType },
2929
/// Error for casting rust's `Vec` into numpy array.
3030
FromVec,
31+
/// Error in numpy -> numpy data conversion
32+
Numpy { from: NpyDataType, to: NpyDataType },
3133
}
3234

3335
impl ArrayCastError {
@@ -46,18 +48,16 @@ impl fmt::Display for ArrayCastError {
4648
write!(f, "Cast failed: from={:?}, to={:?}", from, to)
4749
}
4850
ArrayCastError::FromVec => write!(f, "Cast failed: FromVec (maybe invalid dimension)"),
51+
ArrayCastError::Numpy { from, to } => write!(
52+
f,
53+
"Cast failed: from=ndarray(dtype={:?}), to=ndarray(dtype={:?})",
54+
from, to
55+
),
4956
}
5057
}
5158
}
5259

53-
impl error::Error for ArrayCastError {
54-
fn description(&self) -> &str {
55-
match self {
56-
ArrayCastError::ToRust { .. } => "ArrayCast failed(IntoArray)",
57-
ArrayCastError::FromVec => "ArrayCast failed(FromVec)",
58-
}
59-
}
60-
}
60+
impl error::Error for ArrayCastError {}
6161

6262
impl IntoPyErr for ArrayCastError {
6363
fn into_pyerr(self, msg: &str) -> PyErr {
@@ -67,6 +67,10 @@ impl IntoPyErr for ArrayCastError {
6767
from, to, msg
6868
),
6969
ArrayCastError::FromVec => format!("ArrayCastError::FromVec: {}", msg),
70+
ArrayCastError::Numpy { from, to } => format!(
71+
"ArrayCastError::Numpy: from: {:?}, to: {:?}, msg: {}",
72+
from, to, msg
73+
),
7074
};
7175
PyErr::new::<exc::TypeError, _>(msg)
7276
}

tests/array.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,17 @@ macro_rules! small_array_test {
181181
}
182182

183183
small_array_test!(i8 u8 i16 u16 i32 u32 i64 u64);
184+
185+
#[test]
186+
fn array_cast() {
187+
let gil = pyo3::Python::acquire_gil();
188+
let py = gil.python();
189+
let np = PyArrayModule::import(py).unwrap();
190+
let vec2 = vec![vec![1.0, 2.0, 3.0]; 2];
191+
let arr_f64 = PyArray::from_vec2(gil.python(), &np, &vec2).unwrap();
192+
let arr_i32: PyArray<i32> = arr_f64.cast(py, &np, false).unwrap();
193+
assert_eq!(
194+
arr_i32.as_array().unwrap(),
195+
array![[1, 2, 3], [1, 2, 3]].into_dyn()
196+
);
197+
}

0 commit comments

Comments
 (0)