Skip to content

Commit 5a77972

Browse files
committed
Implement IntoPyArray for ndarray::ArrayBase<OwnedRepr..>
1 parent 9d48ef4 commit 5a77972

File tree

4 files changed

+82
-29
lines changed

4 files changed

+82
-29
lines changed

src/array.rs

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::mem;
99
use std::os::raw::c_int;
1010
use std::ptr;
1111

12-
use convert::{NpyIndex, ToNpyDims};
12+
use convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
1313
use error::{ErrorKind, IntoPyResult};
1414
use slice_box::SliceBox;
1515
use types::{NpyDataType, TypeNum};
@@ -265,6 +265,10 @@ impl<T, D> PyArray<T, D> {
265265
let ptr = self.as_array_ptr();
266266
(*ptr).data as *mut T
267267
}
268+
269+
pub(crate) unsafe fn copy_ptr(&self, other: *const T, len: usize) {
270+
ptr::copy_nonoverlapping(other, self.data(), len)
271+
}
268272
}
269273

270274
impl<T: TypeNum, D: Dimension> PyArray<T, D> {
@@ -306,7 +310,7 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
306310
unsafe { PyArray::new_(py, dims, ptr::null_mut(), flags) }
307311
}
308312

309-
unsafe fn new_<'py, ID>(
313+
pub(crate) unsafe fn new_<'py, ID>(
310314
py: Python<'py>,
311315
dims: ID,
312316
strides: *mut npy_intp,
@@ -330,16 +334,17 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
330334
Self::from_owned_ptr(py, ptr)
331335
}
332336

333-
pub(crate) unsafe fn new_with_data<'py, ID>(
337+
pub(crate) unsafe fn from_boxed_slice<'py, ID>(
334338
py: Python<'py>,
335339
dims: ID,
336340
strides: *mut npy_intp,
337-
slice: &SliceBox<T>,
341+
slice: Box<[T]>,
338342
) -> &'py Self
339343
where
340344
ID: IntoDimension<Dim = D>,
341345
{
342346
let dims = dims.into_dimension();
347+
let slice = SliceBox::new(slice);
343348
let ptr = PY_ARRAY_API.PyArray_New(
344349
PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
345350
dims.ndim_cint(),
@@ -388,7 +393,7 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
388393
}
389394
}
390395

391-
/// Construct PyArray from ndarray::Array.
396+
/// Construct PyArray from `ndarray::ArrayBase`.
392397
///
393398
/// This method allocates memory in Python's heap via numpy api, and then copies all elements
394399
/// of the array there.
@@ -398,25 +403,32 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
398403
/// # extern crate pyo3; extern crate numpy; #[macro_use] extern crate ndarray; fn main() {
399404
/// use numpy::PyArray;
400405
/// let gil = pyo3::Python::acquire_gil();
401-
/// let pyarray = PyArray::from_ndarray(gil.python(), &array![[1, 2], [3, 4]]);
406+
/// let pyarray = PyArray::from_array(gil.python(), &array![[1, 2], [3, 4]]);
402407
/// assert_eq!(pyarray.as_array().unwrap(), array![[1, 2], [3, 4]]);
403408
/// # }
404409
/// ```
405-
pub fn from_ndarray<'py, S>(py: Python<'py>, arr: &ArrayBase<S, D>) -> &'py Self
410+
pub fn from_array<'py, S>(py: Python<'py>, arr: &ArrayBase<S, D>) -> &'py Self
406411
where
407412
S: Data<Elem = T>,
408413
{
409-
let len = arr.len();
410-
let mut strides: Vec<_> = arr
411-
.strides()
412-
.into_iter()
413-
.map(|n| n * mem::size_of::<T>() as npy_intp)
414-
.collect();
415-
unsafe {
416-
let array = PyArray::new_(py, arr.raw_dim(), strides.as_mut_ptr() as *mut npy_intp, 0);
417-
ptr::copy_nonoverlapping(arr.as_ptr(), array.data(), len);
418-
array
419-
}
414+
ToPyArray::to_pyarray(arr, py)
415+
}
416+
417+
/// Construct PyArray from `ndarray::Array`.
418+
///
419+
/// This method uses internal `Vec` of `ndarray::Array` as numpy array.
420+
///
421+
/// # Example
422+
/// ```
423+
/// # extern crate pyo3; extern crate numpy; #[macro_use] extern crate ndarray; fn main() {
424+
/// use numpy::PyArray;
425+
/// let gil = pyo3::Python::acquire_gil();
426+
/// let pyarray = PyArray::from_owned_array(gil.python(), array![[1, 2], [3, 4]]);
427+
/// assert_eq!(pyarray.as_array().unwrap(), array![[1, 2], [3, 4]]);
428+
/// # }
429+
/// ```
430+
pub fn from_owned_array<'py>(py: Python<'py>, arr: Array<T, D>) -> &'py Self {
431+
IntoPyArray::into_pyarray(arr, py)
420432
}
421433

422434
/// Get the immutable view of the internal data of `PyArray`, as `ndarray::ArrayView`.
@@ -565,8 +577,7 @@ impl<T: TypeNum> PyArray<T, Ix1> {
565577
pub fn from_slice<'py>(py: Python<'py>, slice: &[T]) -> &'py Self {
566578
let array = PyArray::new(py, [slice.len()], false);
567579
unsafe {
568-
let src = slice.as_ptr() as *mut T;
569-
ptr::copy_nonoverlapping(src, array.data(), slice.len());
580+
array.copy_ptr(slice.as_ptr(), slice.len());
570581
}
571582
array
572583
}

src/convert.rs

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
//! Defines conversion traits between rust types and numpy data types.
22
3-
use ndarray::{ArrayBase, Data, Dimension, IntoDimension, Ix1};
3+
use ndarray::{ArrayBase, Data, Dimension, IntoDimension, Ix1, OwnedRepr};
44
use pyo3::Python;
5-
use slice_box::SliceBox;
65

76
use std::mem;
87
use std::os::raw::c_int;
98
use std::ptr;
109

1110
use super::*;
11+
use npyffi::npy_intp;
1212

1313
/// Covnersion trait from some rust types to `PyArray`.
1414
///
@@ -34,10 +34,7 @@ impl<T: TypeNum> IntoPyArray for Box<[T]> {
3434
type Dim = Ix1;
3535
fn into_pyarray<'py>(self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
3636
let len = self.len();
37-
unsafe {
38-
let slice = SliceBox::new(self);
39-
PyArray::new_with_data(py, [len], ptr::null_mut(), slice)
40-
}
37+
unsafe { PyArray::from_boxed_slice(py, [len], ptr::null_mut(), self) }
4138
}
4239
}
4340

@@ -49,6 +46,21 @@ impl<T: TypeNum> IntoPyArray for Vec<T> {
4946
}
5047
}
5148

49+
impl<A, D> IntoPyArray for ArrayBase<OwnedRepr<A>, D>
50+
where
51+
A: TypeNum,
52+
D: Dimension,
53+
{
54+
type Item = A;
55+
type Dim = D;
56+
fn into_pyarray<'py>(self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
57+
let mut strides = npy_strides(&self);
58+
let dim = self.raw_dim();
59+
let boxed = self.into_raw_vec().into_boxed_slice();
60+
unsafe { PyArray::from_boxed_slice(py, dim, strides.as_mut_ptr() as *mut npy_intp, boxed) }
61+
}
62+
}
63+
5264
/// Conversion trait from rust types to `PyArray`.
5365
///
5466
/// This trait takes `&self`, which means **it alocates in Python heap and then copies
@@ -85,10 +97,29 @@ where
8597
type Item = A;
8698
type Dim = D;
8799
fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
88-
PyArray::from_ndarray(py, self)
100+
let len = self.len();
101+
let mut strides = npy_strides(self);
102+
unsafe {
103+
let array = PyArray::new_(py, self.raw_dim(), strides.as_mut_ptr() as *mut npy_intp, 0);
104+
array.copy_ptr(self.as_ptr(), len);
105+
array
106+
}
89107
}
90108
}
91109

110+
fn npy_strides<S, D, A>(array: &ArrayBase<S, D>) -> Vec<npyffi::npy_intp>
111+
where
112+
S: Data<Elem = A>,
113+
D: Dimension,
114+
A: TypeNum,
115+
{
116+
array
117+
.strides()
118+
.into_iter()
119+
.map(|n| n * mem::size_of::<A>() as npyffi::npy_intp)
120+
.collect()
121+
}
122+
92123
/// Utility trait to specify the dimention of array
93124
pub trait ToNpyDims: Dimension {
94125
fn ndim_cint(&self) -> c_int {

src/slice_box.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ impl<T> PyObjectAlloc<SliceBox<T>> for SliceBox<T> {
6868
/// Calls the rust destructor for the object.
6969
unsafe fn drop(py: Python, obj: *mut ffi::PyObject) {
7070
let data = (*(obj as *mut SliceBox<T>)).inner;
71-
let box_ = Box::from_raw(data);
72-
drop(box_);
71+
let boxed_slice = Box::from_raw(data);
72+
drop(boxed_slice);
7373
<Self as typeob::PyTypeInfo>::BaseType::drop(py, obj);
7474
}
7575
unsafe fn dealloc(py: Python, obj: *mut ffi::PyObject) {

tests/array.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,14 @@ fn into_pyarray_vec() {
214214
let arr = a.into_pyarray(gil.python());
215215
assert_eq!(arr.as_slice().unwrap(), &[1, 2, 3])
216216
}
217+
218+
#[test]
219+
fn into_pyarray_array() {
220+
let gil = pyo3::Python::acquire_gil();
221+
let arr = Array3::<f64>::zeros((3, 4, 2));
222+
let shape = arr.shape().iter().cloned().collect::<Vec<_>>();
223+
let strides = arr.strides().iter().map(|d| d * 8).collect::<Vec<_>>();
224+
let py_arr = arr.into_pyarray(gil.python());
225+
assert_eq!(py_arr.shape(), shape.as_slice());
226+
assert_eq!(py_arr.strides(), strides.as_slice());
227+
}

0 commit comments

Comments
 (0)