Skip to content

Commit cd3f156

Browse files
committed
Implement PyArray::copy_to, move_to
1 parent 0b941e4 commit cd3f156

File tree

2 files changed

+72
-8
lines changed

2 files changed

+72
-8
lines changed

src/array.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,4 +455,64 @@ impl<T: TypeNum> PyArray<T> {
455455
Self::from_owned_ptr(py, ptr)
456456
}
457457
}
458+
459+
/// Copies self into `other`, performing a data-type conversion if necessary.
460+
/// # Example
461+
/// ```
462+
/// # extern crate pyo3; extern crate numpy; fn main() {
463+
/// use numpy::{PyArray, PyArrayModule, IntoPyArray};
464+
/// let gil = pyo3::Python::acquire_gil();
465+
/// let np = PyArrayModule::import(gil.python()).unwrap();
466+
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), &np, 2.0, 5.0, 1.0);
467+
/// let mut pyarray_i = PyArray::<i64>::new(gil.python(), &np, &[3]);
468+
/// assert!(pyarray_f.copy_to(&np, &mut pyarray_i).is_ok());
469+
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
470+
/// # }
471+
pub fn copy_to<U: TypeNum>(
472+
&self,
473+
np: &PyArrayModule,
474+
other: &mut PyArray<U>,
475+
) -> Result<(), ArrayCastError> {
476+
let self_ptr = self.as_array_ptr();
477+
let other_ptr = other.as_array_ptr();
478+
let result = unsafe { np.PyArray_CopyInto(other_ptr, self_ptr) };
479+
if result == -1 {
480+
Err(ArrayCastError::Numpy {
481+
from: T::npy_data_type(),
482+
to: U::npy_data_type(),
483+
})
484+
} else {
485+
Ok(())
486+
}
487+
}
488+
489+
/// Move the data of self into `other`, performing a data-type conversion if necessary.
490+
/// # Example
491+
/// ```
492+
/// # extern crate pyo3; extern crate numpy; fn main() {
493+
/// use numpy::{PyArray, PyArrayModule, IntoPyArray};
494+
/// let gil = pyo3::Python::acquire_gil();
495+
/// let np = PyArrayModule::import(gil.python()).unwrap();
496+
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), &np, 2.0, 5.0, 1.0);
497+
/// let mut pyarray_i = PyArray::<i64>::new(gil.python(), &np, &[3]);
498+
/// assert!(pyarray_f.move_to(&np, &mut pyarray_i).is_ok());
499+
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
500+
/// # }
501+
pub fn move_to<U: TypeNum>(
502+
self,
503+
np: &PyArrayModule,
504+
other: &mut PyArray<U>,
505+
) -> Result<(), ArrayCastError> {
506+
let self_ptr = self.as_array_ptr();
507+
let other_ptr = other.as_array_ptr();
508+
let result = unsafe { np.PyArray_MoveInto(other_ptr, self_ptr) };
509+
if result == -1 {
510+
Err(ArrayCastError::Numpy {
511+
from: T::npy_data_type(),
512+
to: U::npy_data_type(),
513+
})
514+
} else {
515+
Ok(())
516+
}
517+
}
458518
}

src/error.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ pub enum ArrayCastError {
2828
ToRust { from: NpyDataType, to: NpyDataType },
2929
/// Error for casting rust's `Vec` into numpy array.
3030
FromVec,
31+
/// Error in numpy -> numpy data conversion
32+
Numpy { from: NpyDataType, to: NpyDataType },
3133
}
3234

3335
impl ArrayCastError {
@@ -46,18 +48,16 @@ impl fmt::Display for ArrayCastError {
4648
write!(f, "Cast failed: from={:?}, to={:?}", from, to)
4749
}
4850
ArrayCastError::FromVec => write!(f, "Cast failed: FromVec (maybe invalid dimension)"),
51+
ArrayCastError::Numpy { from, to } => write!(
52+
f,
53+
"Cast failed: from=ndarray(dtype={:?}), to=ndarray(dtype={:?})",
54+
from, to
55+
),
4956
}
5057
}
5158
}
5259

53-
impl error::Error for ArrayCastError {
54-
fn description(&self) -> &str {
55-
match self {
56-
ArrayCastError::ToRust { .. } => "ArrayCast failed(IntoArray)",
57-
ArrayCastError::FromVec => "ArrayCast failed(FromVec)",
58-
}
59-
}
60-
}
60+
impl error::Error for ArrayCastError {}
6161

6262
impl IntoPyErr for ArrayCastError {
6363
fn into_pyerr(self, msg: &str) -> PyErr {
@@ -67,6 +67,10 @@ impl IntoPyErr for ArrayCastError {
6767
from, to, msg
6868
),
6969
ArrayCastError::FromVec => format!("ArrayCastError::FromVec: {}", msg),
70+
ArrayCastError::Numpy { from, to } => format!(
71+
"ArrayCastError::Numpy: from: {:?}, to: {:?}, msg: {}",
72+
from, to, msg
73+
),
7074
};
7175
PyErr::new::<exc::TypeError, _>(msg)
7276
}

0 commit comments

Comments
 (0)