Skip to content

Commit 2ed632a

Browse files
committed
Re-Implement PyArray::from_iter and PyArray::from_exact_iter
1 parent fb71b3e commit 2ed632a

File tree

2 files changed

+93
-23
lines changed

2 files changed

+93
-23
lines changed

src/array.rs

Lines changed: 83 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
//! Safe interface for NumPy ndarray
2-
32
use ndarray::*;
43
use npyffi::{self, npy_intp, PY_ARRAY_API};
54
use pyo3::*;
@@ -111,7 +110,7 @@ impl<T> PyArray<T> {
111110
/// # extern crate pyo3; extern crate numpy; fn main() {
112111
/// use numpy::PyArray;
113112
/// let gil = pyo3::Python::acquire_gil();
114-
/// let arr = PyArray::<f64>::new(gil.python(), [4, 5, 6]);
113+
/// let arr = PyArray::<f64>::new(gil.python(), [4, 5, 6], false);
115114
/// assert_eq!(arr.ndim(), 3);
116115
/// # }
117116
/// ```
@@ -129,7 +128,7 @@ impl<T> PyArray<T> {
129128
/// # extern crate pyo3; extern crate numpy; fn main() {
130129
/// use numpy::PyArray;
131130
/// let gil = pyo3::Python::acquire_gil();
132-
/// let arr = PyArray::<f64>::new(gil.python(), [4, 5, 6]);
131+
/// let arr = PyArray::<f64>::new(gil.python(), [4, 5, 6], false);
133132
/// assert_eq!(arr.strides(), &[240, 48, 8]);
134133
/// # }
135134
/// ```
@@ -151,7 +150,7 @@ impl<T> PyArray<T> {
151150
/// # extern crate pyo3; extern crate numpy; fn main() {
152151
/// use numpy::PyArray;
153152
/// let gil = pyo3::Python::acquire_gil();
154-
/// let arr = PyArray::<f64>::new(gil.python(), [4, 5, 6]);
153+
/// let arr = PyArray::<f64>::new(gil.python(), [4, 5, 6], false);
155154
/// assert_eq!(arr.shape(), &[4, 5, 6]);
156155
/// # }
157156
/// ```
@@ -201,33 +200,33 @@ impl<T> PyArray<T> {
201200
}
202201

203202
// TODO: we should provide safe access API
204-
unsafe fn access(&self, index: &[isize]) -> *const T {
205-
let strides = self.strides();
206-
let mut start = self.data();
203+
unsafe fn get_unchecked(&self, index: &[isize]) -> *const T {
207204
let size = mem::size_of::<T>() as isize;
208-
for (i, idx) in index.iter().enumerate() {
209-
start = start.offset(strides[i] * idx / size);
210-
}
211-
start
205+
index
206+
.iter()
207+
.zip(self.strides())
208+
.fold(self.data(), |pointer, (idx, stride)| {
209+
pointer.offset(stride * idx / size)
210+
})
212211
}
213212

214213
// TODO: we should provide safe access API
215214
#[inline(always)]
216-
unsafe fn access_mut(&self, index: &[isize]) -> *mut T {
217-
self.access(index) as *mut T
215+
unsafe fn get_unchecked_mut(&self, index: &[isize]) -> *mut T {
216+
self.get_unchecked(index) as *mut T
218217
}
219218
}
220219

221220
impl<T: TypeNum> PyArray<T> {
222-
/// Construct one-dimension PyArray from boxed slice.
221+
/// Construct one-dimension PyArray from slice.
223222
///
224223
/// # Example
225224
/// ```
226225
/// # extern crate pyo3; extern crate numpy; fn main() {
227226
/// use numpy::PyArray;
228227
/// let gil = pyo3::Python::acquire_gil();
229-
/// let slice = vec![1, 2, 3, 4, 5].into_boxed_slice();
230-
/// let pyarray = PyArray::from_boxed_slice(gil.python(), slice);
228+
/// let array = [1, 2, 3, 4, 5];
229+
/// let pyarray = PyArray::from_slice(gil.python(), &array);
231230
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
232231
/// # }
233232
/// ```
@@ -240,8 +239,33 @@ impl<T: TypeNum> PyArray<T> {
240239
array
241240
}
242241

242+
/// Construct one-dimension PyArray from `impl ExactSizeIterator`.
243+
///
244+
/// # Example
245+
/// ```
246+
/// # extern crate pyo3; extern crate numpy; fn main() {
247+
/// use numpy::PyArray;
248+
/// use std::collections::BTreeSet;
249+
/// let gil = pyo3::Python::acquire_gil();
250+
/// let vec = vec![1, 2, 3, 4, 5];
251+
/// let pyarray = PyArray::from_iter(gil.python(), vec.iter().map(|&x| x));
252+
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
253+
/// # }
254+
/// ```
255+
pub fn from_exact_iter(py: Python, iter: impl ExactSizeIterator<Item = T>) -> &Self {
256+
let array = Self::new(py, [iter.len()], false);
257+
unsafe {
258+
for (i, item) in iter.enumerate() {
259+
*array.get_unchecked_mut(&[i as isize]) = item;
260+
}
261+
}
262+
array
263+
}
264+
243265
/// Construct one-dimension PyArray from `impl IntoIterator`.
244266
///
267+
/// This method can allocate multiple times and not fast.
268+
/// When you can use [from_exact_iter](method.from_exact_iter.html), please use it.
245269
/// # Example
246270
/// ```
247271
/// # extern crate pyo3; extern crate numpy; fn main() {
@@ -253,13 +277,26 @@ impl<T: TypeNum> PyArray<T> {
253277
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
254278
/// # }
255279
/// ```
256-
pub fn from_iter(py: Python, iter: impl ExactSizeIterator<Item = T>) -> &Self {
257-
let array = Self::new(py, [iter.len()], false);
280+
pub fn from_iter(py: Python, iter: impl IntoIterator<Item = T>) -> &Self {
281+
// ↓ max cached size of ndarray
282+
let mut capacity = 1024 / mem::size_of::<T>();
283+
let array = Self::new(py, [capacity], false);
284+
let mut length = 0;
258285
unsafe {
259286
for (i, item) in iter.into_iter().enumerate() {
260-
*array.access_mut(&[i as isize]) = item;
287+
length += 1;
288+
if length >= capacity {
289+
capacity *= 2;
290+
array
291+
.resize_([capacity], 0, NPY_ORDER::NPY_ANYORDER)
292+
.expect("PyArray::from_iter: Failed to allocate memory");
293+
}
294+
*array.get_unchecked_mut(&[i as isize]) = item;
261295
}
262296
}
297+
if capacity > length {
298+
array.resize_([length], 0, NPY_ORDER::NPY_ANYORDER).unwrap()
299+
}
263300
array
264301
}
265302

@@ -295,7 +332,7 @@ impl<T: TypeNum> PyArray<T> {
295332
unsafe {
296333
for y in 0..v.len() {
297334
for x in 0..last_len {
298-
*array.access_mut(&[y as isize, x as isize]) = v[y][x].clone();
335+
*array.get_unchecked_mut(&[y as isize, x as isize]) = v[y][x].clone();
299336
}
300337
}
301338
}
@@ -348,7 +385,7 @@ impl<T: TypeNum> PyArray<T> {
348385
for z in 0..v.len() {
349386
for y in 0..dim2 {
350387
for x in 0..dim3 {
351-
*array.access_mut(&[z as isize, y as isize, x as isize]) =
388+
*array.get_unchecked_mut(&[z as isize, y as isize, x as isize]) =
352389
v[z][y][x].clone();
353390
}
354391
}
@@ -633,13 +670,36 @@ impl<T: TypeNum> PyArray<T> {
633670
Ok(unsafe { PyArray::<T>::from_owned_ptr(self.py(), ptr) })
634671
}
635672
}
673+
674+
// TODO: expose?
675+
fn resize_<'py, D: ToNpyDims>(
676+
&'py self,
677+
dims: D,
678+
check_ref: c_int,
679+
order: NPY_ORDER,
680+
) -> Result<(), ErrorKind> {
681+
let mut np_dims = dims.to_npy_dims();
682+
let res = unsafe {
683+
PY_ARRAY_API.PyArray_Resize(
684+
self.as_array_ptr(),
685+
&mut np_dims as *mut npyffi::PyArray_Dims,
686+
check_ref,
687+
order,
688+
)
689+
};
690+
if res.is_null() {
691+
Err(ErrorKind::dims_cast(self, dims))
692+
} else {
693+
Ok(())
694+
}
695+
}
636696
}
637697

638698
#[test]
639-
fn test_access() {
699+
fn test_get_unchecked() {
640700
let gil = pyo3::Python::acquire_gil();
641701
let array = PyArray::from_slice(gil.python(), &[1i32, 2, 3]);
642702
unsafe {
643-
assert_eq!(*array.access(&[1]), 2);
703+
assert_eq!(*array.get_unchecked(&[1]), 2);
644704
}
645705
}

tests/array.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,16 @@ fn iter_to_pyarray() {
102102
);
103103
}
104104

105+
#[test]
106+
fn long_iter_to_pyarray() {
107+
let gil = pyo3::Python::acquire_gil();
108+
let arr = PyArray::from_iter(gil.python(), (0u32..512).map(|x| x));
109+
let slice = arr.as_slice().unwrap();
110+
for (i, &elem) in slice.iter().enumerate() {
111+
assert_eq!(i as u32, elem);
112+
}
113+
}
114+
105115
#[test]
106116
fn is_instance() {
107117
let gil = pyo3::Python::acquire_gil();

0 commit comments

Comments
 (0)