Skip to content

Commit fb71b3e

Browse files
committed
Remove IntoPyObject in favour of ToPyObject(&self, py) -> &PyArray
1 parent e372f0b commit fb71b3e

File tree

4 files changed

+167
-151
lines changed

4 files changed

+167
-151
lines changed

src/array.rs

Lines changed: 130 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
//! Safe interface for NumPy ndarray
22
33
use ndarray::*;
4-
use npyffi::{self, PY_ARRAY_API};
4+
use npyffi::{self, npy_intp, PY_ARRAY_API};
55
use pyo3::*;
6+
use std::iter::ExactSizeIterator;
67
use std::marker::PhantomData;
7-
use std::os::raw::c_void;
8-
use std::ptr::null_mut;
8+
use std::mem;
9+
use std::os::raw::c_int;
10+
use std::ptr;
911

10-
use super::error::ErrorKind;
11-
use super::*;
12+
use convert::ToNpyDims;
13+
use error::{ErrorKind, IntoPyErr};
14+
use types::{NpyDataType, TypeNum, NPY_ORDER};
1215

1316
/// Interface for [NumPy ndarray](https://docs.scipy.org/doc/numpy/reference/arrays.ndarray.html).
1417
pub struct PyArray<T>(PyObject, PhantomData<T>);
@@ -172,6 +175,47 @@ impl<T> PyArray<T> {
172175
pub fn len(&self) -> usize {
173176
self.shape().iter().fold(1, |a, b| a * b)
174177
}
178+
179+
fn ndarray_shape(&self) -> StrideShape<IxDyn> {
180+
// FIXME may be done more simply
181+
let shape: Shape<_> = Dim(self.shape()).into();
182+
let st: Vec<usize> = self
183+
.strides()
184+
.iter()
185+
.map(|&x| x as usize / ::std::mem::size_of::<T>())
186+
.collect();
187+
shape.strides(Dim(st))
188+
}
189+
190+
fn typenum(&self) -> i32 {
191+
unsafe {
192+
let descr = (*self.as_array_ptr()).descr;
193+
(*descr).type_num
194+
}
195+
}
196+
197+
/// Returns the pointer to the first element of the inner array.
198+
unsafe fn data(&self) -> *mut T {
199+
let ptr = self.as_array_ptr();
200+
(*ptr).data as *mut T
201+
}
202+
203+
// TODO: we should provide safe access API
204+
unsafe fn access(&self, index: &[isize]) -> *const T {
205+
let strides = self.strides();
206+
let mut start = self.data();
207+
let size = mem::size_of::<T>() as isize;
208+
for (i, idx) in index.iter().enumerate() {
209+
start = start.offset(strides[i] * idx / size);
210+
}
211+
start
212+
}
213+
214+
// TODO: we should provide safe access API
215+
#[inline(always)]
216+
unsafe fn access_mut(&self, index: &[isize]) -> *mut T {
217+
self.access(index) as *mut T
218+
}
175219
}
176220

177221
impl<T: TypeNum> PyArray<T> {
@@ -187,8 +231,13 @@ impl<T: TypeNum> PyArray<T> {
187231
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
188232
/// # }
189233
/// ```
190-
pub fn from_boxed_slice(py: Python, v: Box<[T]>) -> &Self {
191-
IntoPyArray::into_pyarray(v, py)
234+
pub fn from_slice<'py>(py: Python<'py>, slice: &[T]) -> &'py Self {
235+
let array = PyArray::new(py, [slice.len()], false);
236+
unsafe {
237+
let src = slice.as_ptr() as *mut T;
238+
ptr::copy_nonoverlapping(src, array.data(), slice.len());
239+
}
240+
array
192241
}
193242

194243
/// Construct one-dimension PyArray from `impl IntoIterator`.
@@ -204,23 +253,14 @@ impl<T: TypeNum> PyArray<T> {
204253
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
205254
/// # }
206255
/// ```
207-
pub fn from_iter(py: Python, i: impl IntoIterator<Item = T>) -> &Self {
208-
i.into_iter().collect::<Vec<_>>().into_pyarray(py)
209-
}
210-
211-
/// Construct one-dimension PyArray from Vec.
212-
///
213-
/// # Example
214-
/// ```
215-
/// # extern crate pyo3; extern crate numpy; fn main() {
216-
/// use numpy::PyArray;
217-
/// let gil = pyo3::Python::acquire_gil();
218-
/// let pyarray = PyArray::from_vec(gil.python(), vec![1, 2, 3, 4, 5]);
219-
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
220-
/// # }
221-
/// ```
222-
pub fn from_vec(py: Python, v: Vec<T>) -> &Self {
223-
IntoPyArray::into_pyarray(v, py)
256+
pub fn from_iter(py: Python, iter: impl ExactSizeIterator<Item = T>) -> &Self {
257+
let array = Self::new(py, [iter.len()], false);
258+
unsafe {
259+
for (i, item) in iter.into_iter().enumerate() {
260+
*array.access_mut(&[i as isize]) = item;
261+
}
262+
}
263+
array
224264
}
225265

226266
/// Construct a two-dimension PyArray from `Vec<Vec<T>>`.
@@ -239,7 +279,10 @@ impl<T: TypeNum> PyArray<T> {
239279
/// assert!(PyArray::from_vec2(gil.python(), &vec![vec![1], vec![2, 3]]).is_err());
240280
/// # }
241281
/// ```
242-
pub fn from_vec2<'py>(py: Python<'py>, v: &Vec<Vec<T>>) -> Result<&'py Self, ErrorKind> {
282+
pub fn from_vec2<'py>(py: Python<'py>, v: &Vec<Vec<T>>) -> Result<&'py Self, ErrorKind>
283+
where
284+
T: Clone,
285+
{
243286
let last_len = v.last().map_or(0, |v| v.len());
244287
if v.iter().any(|v| v.len() != last_len) {
245288
return Err(ErrorKind::FromVec {
@@ -248,11 +291,15 @@ impl<T: TypeNum> PyArray<T> {
248291
});
249292
}
250293
let dims = [v.len(), last_len];
251-
let flattend: Vec<_> = v.iter().cloned().flatten().collect();
294+
let array = Self::new(py, dims, false);
252295
unsafe {
253-
let data = convert::into_raw(flattend);
254-
Ok(PyArray::new_(py, dims, null_mut(), data))
296+
for y in 0..v.len() {
297+
for x in 0..last_len {
298+
*array.access_mut(&[y as isize, x as isize]) = v[y][x].clone();
299+
}
300+
}
255301
}
302+
Ok(array)
256303
}
257304

258305
/// Construct a three-dimension PyArray from `Vec<Vec<Vec<T>>>`.
@@ -277,7 +324,10 @@ impl<T: TypeNum> PyArray<T> {
277324
pub fn from_vec3<'py>(
278325
py: Python<'py>,
279326
v: &Vec<Vec<Vec<T>>>,
280-
) -> Result<&'py PyArray<T>, ErrorKind> {
327+
) -> Result<&'py PyArray<T>, ErrorKind>
328+
where
329+
T: Clone,
330+
{
281331
let dim2 = v.last().map_or(0, |v| v.len());
282332
if v.iter().any(|v| v.len() != dim2) {
283333
return Err(ErrorKind::FromVec {
@@ -293,11 +343,18 @@ impl<T: TypeNum> PyArray<T> {
293343
});
294344
}
295345
let dims = [v.len(), dim2, dim3];
296-
let flattend: Vec<_> = v.iter().flat_map(|v| v.iter().cloned().flatten()).collect();
346+
let array = Self::new(py, dims, false);
297347
unsafe {
298-
let data = convert::into_raw(flattend);
299-
Ok(PyArray::new_(py, dims, null_mut(), data))
348+
for z in 0..v.len() {
349+
for y in 0..dim2 {
350+
for x in 0..dim3 {
351+
*array.access_mut(&[z as isize, y as isize, x as isize]) =
352+
v[z][y][x].clone();
353+
}
354+
}
355+
}
300356
}
357+
Ok(array)
301358
}
302359

303360
/// Construct PyArray from ndarray::Array.
@@ -311,34 +368,22 @@ impl<T: TypeNum> PyArray<T> {
311368
/// assert_eq!(pyarray.as_array().unwrap(), array![[1, 2], [3, 4]].into_dyn());
312369
/// # }
313370
/// ```
314-
pub fn from_ndarray<D>(py: Python, arr: Array<T, D>) -> &Self
371+
pub fn from_ndarray<'py, S, D>(py: Python<'py>, arr: &ArrayBase<S, D>) -> &'py Self
315372
where
373+
S: Data<Elem = T>,
316374
D: Dimension,
317375
{
318-
IntoPyArray::into_pyarray(arr, py)
319-
}
320-
321-
/// Returns the pointer to the first element of the inner array.
322-
unsafe fn data(&self) -> *mut T {
323-
let ptr = self.as_array_ptr();
324-
(*ptr).data as *mut T
325-
}
326-
327-
fn ndarray_shape(&self) -> StrideShape<IxDyn> {
328-
// FIXME may be done more simply
329-
let shape: Shape<_> = Dim(self.shape()).into();
330-
let st: Vec<usize> = self
376+
let dims: Vec<_> = arr.shape().iter().cloned().collect();
377+
let len = dims.iter().fold(1, |prod, d| prod * d);
378+
let mut strides: Vec<_> = arr
331379
.strides()
332-
.iter()
333-
.map(|&x| x as usize / ::std::mem::size_of::<T>())
380+
.into_iter()
381+
.map(|n| n * mem::size_of::<T>() as npy_intp)
334382
.collect();
335-
shape.strides(Dim(st))
336-
}
337-
338-
fn typenum(&self) -> i32 {
339383
unsafe {
340-
let descr = (*self.as_array_ptr()).descr;
341-
(*descr).type_num
384+
let array = PyArray::new_(py, &*dims, strides.as_mut_ptr() as *mut npy_intp, 0);
385+
ptr::copy_nonoverlapping(arr.as_ptr(), array.data(), len);
386+
array
342387
}
343388
}
344389

@@ -385,46 +430,44 @@ impl<T: TypeNum> PyArray<T> {
385430
unsafe { Ok(::std::slice::from_raw_parts_mut(self.data(), self.len())) }
386431
}
387432

388-
/// Construct a new PyArray given a raw pointer and dimensions.
433+
/// Creates a new uninitialized PyArray in python heap.
434+
///
435+
/// See also [PyArray_SimpleNew](https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_SimpleNew).
389436
///
390-
/// Please use `new` or from methods instead.
391-
pub unsafe fn new_<'py, D: ToNpyDims>(
437+
/// # Example
438+
/// ```
439+
/// # extern crate pyo3; extern crate numpy; #[macro_use] extern crate ndarray; fn main() {
440+
/// use numpy::PyArray;
441+
/// let gil = pyo3::Python::acquire_gil();
442+
/// let pyarray = PyArray::<i32>::new(gil.python(), [4, 5, 6]);
443+
/// assert_eq!(pyarray.shape(), &[4, 5, 6]);
444+
/// # }
445+
/// ```
446+
pub fn new<'py, D: ToNpyDims>(py: Python<'py>, dims: D, is_fortran: bool) -> &'py Self {
447+
let flags = if is_fortran { 1 } else { 0 };
448+
unsafe { PyArray::new_(py, dims, ptr::null_mut(), flags) }
449+
}
450+
451+
unsafe fn new_<'py, D: ToNpyDims>(
392452
py: Python<'py>,
393453
dims: D,
394454
strides: *mut npy_intp,
395-
data: *mut c_void,
455+
flag: c_int,
396456
) -> &'py Self {
397457
let ptr = PY_ARRAY_API.PyArray_New(
398458
PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
399459
dims.dims_len(),
400460
dims.dims_ptr(),
401461
T::typenum_default(),
402-
strides,
403-
data,
462+
strides, // strides
463+
ptr::null_mut(), // data
404464
0, // itemsize
405-
0, // flag
465+
flag, // flag
406466
::std::ptr::null_mut(), //obj
407467
);
408468
Self::from_owned_ptr(py, ptr)
409469
}
410470

411-
/// Creates a new uninitialized array.
412-
///
413-
/// See also [PyArray_SimpleNew](https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_SimpleNew).
414-
///
415-
/// # Example
416-
/// ```
417-
/// # extern crate pyo3; extern crate numpy; #[macro_use] extern crate ndarray; fn main() {
418-
/// use numpy::PyArray;
419-
/// let gil = pyo3::Python::acquire_gil();
420-
/// let pyarray = PyArray::<i32>::new(gil.python(), [4, 5, 6]);
421-
/// assert_eq!(pyarray.shape(), &[4, 5, 6]);
422-
/// # }
423-
/// ```
424-
pub fn new<'py, D: ToNpyDims>(py: Python<'py>, dims: D) -> &'py Self {
425-
unsafe { Self::new_(py, dims, null_mut(), null_mut()) }
426-
}
427-
428471
/// Construct a new nd-dimensional array filled with 0. If `is_fortran` is true, then
429472
/// a fortran order array is created, otherwise a C-order array is created.
430473
///
@@ -591,3 +634,12 @@ impl<T: TypeNum> PyArray<T> {
591634
}
592635
}
593636
}
637+
638+
#[test]
639+
fn test_access() {
640+
let gil = pyo3::Python::acquire_gil();
641+
let array = PyArray::from_slice(gil.python(), &[1i32, 2, 3]);
642+
unsafe {
643+
assert_eq!(*array.access(&[1]), 2);
644+
}
645+
}

0 commit comments

Comments
 (0)