Skip to content

Commit dc3859c

Browse files
committed
Implements intial version of NpyMultiIterArray.
1 parent 2132036 commit dc3859c

File tree

2 files changed

+278
-2
lines changed

2 files changed

+278
-2
lines changed

src/npyiter.rs

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,254 @@ impl<'py, T: 'py> std::iter::Iterator for NpyIterSingleArray<'py, T> {
158158
}
159159
}
160160
}
161+
162+
pub trait MultiIterMode {}
163+
164+
impl MultiIterMode for () {}
165+
166+
pub struct RO<S> {
167+
structure: PhantomData<S>,
168+
}
169+
170+
impl<S: MultiIterMode> MultiIterMode for RO<S> {}
171+
172+
pub struct RW<S> {
173+
structure: PhantomData<S>,
174+
}
175+
176+
impl<S: MultiIterMode> MultiIterMode for RW<S> {}
177+
178+
pub trait MultiIterModeHasManyArrays: MultiIterMode {}
179+
impl MultiIterModeHasManyArrays for RO<RO<()>> {}
180+
impl MultiIterModeHasManyArrays for RO<RW<()>> {}
181+
impl MultiIterModeHasManyArrays for RW<RO<()>> {}
182+
impl MultiIterModeHasManyArrays for RW<RW<()>> {}
183+
184+
impl<S: MultiIterModeHasManyArrays> MultiIterModeHasManyArrays for RO<S> {}
185+
impl<S: MultiIterModeHasManyArrays> MultiIterModeHasManyArrays for RW<S> {}
186+
187+
pub struct NpyMultiIterBuilder<'py, T, S: MultiIterMode> {
188+
flags: npy_uint32,
189+
opflags: Vec<npy_uint32>,
190+
arrays: Vec<&'py PyArrayDyn<T>>,
191+
structure: PhantomData<S>,
192+
}
193+
194+
impl<'py, T: TypeNum> NpyMultiIterBuilder<'py, T, ()> {
195+
pub fn new() -> Self {
196+
Self {
197+
flags: 0,
198+
opflags: Vec::new(),
199+
arrays: Vec::new(),
200+
structure: PhantomData,
201+
}
202+
}
203+
204+
pub fn set(mut self, flag: NpyIterFlag) -> Self {
205+
if flag == NpyIterFlag::ExternalLoop {
206+
// TODO: I don't want to make set fallible, but also we don't want to
207+
// support ExternalLoop yet (maybe ever?).
208+
panic!("rust-numpy does not currently support ExternalLoop access");
209+
}
210+
self.flags |= flag.to_c_enum();
211+
self
212+
}
213+
214+
pub fn unset(mut self, flag: NpyIterFlag) -> Self {
215+
self.flags &= !flag.to_c_enum();
216+
self
217+
}
218+
}
219+
220+
impl<'py, T: TypeNum, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> {
221+
pub fn add_readonly_array<D: ndarray::Dimension>(
222+
mut self,
223+
array: &'py PyArray<T, D>,
224+
) -> NpyMultiIterBuilder<'py, T, RO<S>> {
225+
self.arrays.push(array.into_dyn());
226+
self.opflags.push(NPY_ITER_READONLY);
227+
228+
NpyMultiIterBuilder {
229+
flags: self.flags,
230+
opflags: self.opflags,
231+
arrays: self.arrays,
232+
structure: PhantomData,
233+
}
234+
}
235+
236+
pub fn add_readwrite_array<D: ndarray::Dimension>(
237+
mut self,
238+
array: &'py PyArray<T, D>,
239+
) -> NpyMultiIterBuilder<'py, T, RW<S>> {
240+
self.arrays.push(array.into_dyn());
241+
self.opflags.push(NPY_ITER_READWRITE);
242+
243+
NpyMultiIterBuilder {
244+
flags: self.flags,
245+
opflags: self.opflags,
246+
arrays: self.arrays,
247+
structure: PhantomData,
248+
}
249+
}
250+
}
251+
252+
impl<'py, T: TypeNum, S: MultiIterModeHasManyArrays> NpyMultiIterBuilder<'py, T, S> {
253+
pub fn build(mut self) -> PyResult<NpyMultiIterArray<'py, T, S>> {
254+
assert!(self.arrays.len() == self.opflags.len());
255+
assert!(self.arrays.len() <= i32::MAX as usize);
256+
assert!(2 <= self.arrays.len());
257+
258+
let iter_ptr = unsafe {
259+
PY_ARRAY_API.NpyIter_MultiNew(
260+
self.arrays.len() as i32,
261+
self.arrays
262+
.iter_mut()
263+
.map(|x| x.as_array_ptr())
264+
.collect::<Vec<_>>()
265+
.as_mut_ptr(),
266+
self.flags,
267+
NPY_ORDER::NPY_ANYORDER,
268+
NPY_CASTING::NPY_SAFE_CASTING,
269+
self.opflags.as_mut_ptr(),
270+
ptr::null_mut(),
271+
)
272+
};
273+
let py = self.arrays[0].py();
274+
NpyMultiIterArray::new(iter_ptr, py).ok_or_else(|| PyErr::fetch(py))
275+
}
276+
}
277+
278+
pub struct NpyMultiIterArray<'py, T, S: MultiIterModeHasManyArrays> {
279+
iterator: ptr::NonNull<objects::NpyIter>,
280+
iternext: unsafe extern "C" fn(*mut objects::NpyIter) -> c_int,
281+
empty: bool,
282+
dataptr: *mut *mut c_char,
283+
284+
return_type: PhantomData<T>,
285+
structure: PhantomData<S>,
286+
_py: Python<'py>,
287+
}
288+
289+
impl<'py, T, S: MultiIterModeHasManyArrays> NpyMultiIterArray<'py, T, S> {
290+
fn new(iterator: *mut objects::NpyIter, py: Python<'py>) -> Option<Self> {
291+
let mut iterator = ptr::NonNull::new(iterator)?;
292+
293+
// TODO replace the null second arg with something correct.
294+
let iternext =
295+
unsafe { PY_ARRAY_API.NpyIter_GetIterNext(iterator.as_mut(), ptr::null_mut())? };
296+
let dataptr = unsafe { PY_ARRAY_API.NpyIter_GetDataPtrArray(iterator.as_mut()) };
297+
298+
if dataptr.is_null() {
299+
unsafe { PY_ARRAY_API.NpyIter_Deallocate(iterator.as_mut()) };
300+
}
301+
302+
Some(Self {
303+
iterator,
304+
iternext,
305+
empty: false, // TODO: Handle empty iterators
306+
dataptr,
307+
return_type: PhantomData,
308+
structure: PhantomData,
309+
_py: py,
310+
})
311+
}
312+
}
313+
314+
impl<'py, T, S: MultiIterModeHasManyArrays> Drop for NpyMultiIterArray<'py, T, S> {
315+
fn drop(&mut self) {
316+
let _success = unsafe { PY_ARRAY_API.NpyIter_Deallocate(self.iterator.as_mut()) };
317+
// TODO: Handle _success somehow?
318+
}
319+
}
320+
321+
impl<'py, T: 'py> std::iter::Iterator for NpyMultiIterArray<'py, T, RO<RO<()>>> {
322+
type Item = (&'py T, &'py T);
323+
324+
fn next(&mut self) -> Option<Self::Item> {
325+
if self.empty {
326+
None
327+
} else {
328+
// Note: This pointer is correct and doesn't need to be updated,
329+
// note that we're derefencing a **char into a *char casting to a *T
330+
// and then transforming that into a reference, the value that dataptr
331+
// points to is being updated by iternext to point to the next value.
332+
let retval = Some(unsafe {
333+
(
334+
&*(*self.dataptr as *mut T),
335+
&*(*self.dataptr.offset(1) as *mut T),
336+
)
337+
});
338+
self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
339+
retval
340+
}
341+
}
342+
}
343+
344+
impl<'py, T: 'py> std::iter::Iterator for NpyMultiIterArray<'py, T, RO<RW<()>>> {
345+
type Item = (&'py mut T, &'py T);
346+
347+
fn next(&mut self) -> Option<Self::Item> {
348+
if self.empty {
349+
None
350+
} else {
351+
// Note: This pointer is correct and doesn't need to be updated,
352+
// note that we're derefencing a **char into a *char casting to a *T
353+
// and then transforming that into a reference, the value that dataptr
354+
// points to is being updated by iternext to point to the next value.
355+
let retval = Some(unsafe {
356+
(
357+
&mut *(*self.dataptr as *mut T),
358+
&*(*self.dataptr.offset(1) as *mut T),
359+
)
360+
});
361+
self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
362+
retval
363+
}
364+
}
365+
}
366+
367+
impl<'py, T: 'py> std::iter::Iterator for NpyMultiIterArray<'py, T, RW<RO<()>>> {
368+
type Item = (&'py T, &'py mut T);
369+
370+
fn next(&mut self) -> Option<Self::Item> {
371+
if self.empty {
372+
None
373+
} else {
374+
// Note: This pointer is correct and doesn't need to be updated,
375+
// note that we're derefencing a **char into a *char casting to a *T
376+
// and then transforming that into a reference, the value that dataptr
377+
// points to is being updated by iternext to point to the next value.
378+
let retval = Some(unsafe {
379+
(
380+
&*(*self.dataptr as *mut T),
381+
&mut *(*self.dataptr.offset(1) as *mut T),
382+
)
383+
});
384+
self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
385+
retval
386+
}
387+
}
388+
}
389+
390+
impl<'py, T: 'py> std::iter::Iterator for NpyMultiIterArray<'py, T, RW<RW<()>>> {
391+
type Item = (&'py mut T, &'py mut T);
392+
393+
fn next(&mut self) -> Option<Self::Item> {
394+
if self.empty {
395+
None
396+
} else {
397+
// Note: This pointer is correct and doesn't need to be updated,
398+
// note that we're derefencing a **char into a *char casting to a *T
399+
// and then transforming that into a reference, the value that dataptr
400+
// points to is being updated by iternext to point to the next value.
401+
let retval = Some(unsafe {
402+
(
403+
&mut *(*self.dataptr as *mut T),
404+
&mut *(*self.dataptr.offset(1) as *mut T),
405+
)
406+
});
407+
self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
408+
retval
409+
}
410+
}
411+
}

tests/iter.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ fn get_iter() {
77
let dim = (3, 5);
88
let arr = PyArray::<f64, _>::zeros(gil.python(), dim, false);
99
let mut iter = npyiter::NpyIterBuilder::new(arr)
10-
.add(NpyIterFlag::ReadOnly)
10+
.set(NpyIterFlag::ReadOnly)
1111
.build()
1212
.map_err(|e| e.print(gil.python()))
1313
.unwrap();
@@ -21,7 +21,7 @@ fn sum_iter() -> PyResult<()> {
2121

2222
let arr = PyArray::from_vec2(gil.python(), &vec_data)?;
2323
let iter = npyiter::NpyIterBuilder::new(arr)
24-
.add(NpyIterFlag::ReadOnly)
24+
.set(NpyIterFlag::ReadOnly)
2525
.build()
2626
.map_err(|e| e.print(gil.python()))
2727
.unwrap();
@@ -31,3 +31,28 @@ fn sum_iter() -> PyResult<()> {
3131
assert_eq!(iter.sum::<f64>(), 15.0);
3232
Ok(())
3333
}
34+
35+
#[test]
36+
fn multi_iter() -> PyResult<()> {
37+
let gil = pyo3::Python::acquire_gil();
38+
let vec_data1 = vec![vec![0.0, 1.0], vec![2.0, 3.0], vec![4.0, 5.0]];
39+
let vec_data2 = vec![vec![6.0, 7.0], vec![8.0, 9.0], vec![10.0, 11.0]];
40+
41+
let arr1 = PyArray::from_vec2(gil.python(), &vec_data1)?;
42+
let arr2 = PyArray::from_vec2(gil.python(), &vec_data2)?;
43+
44+
let iter = npyiter::NpyMultiIterBuilder::new()
45+
.add_readonly_array(arr1)
46+
.add_readonly_array(arr2)
47+
.build()
48+
.map_err(|e| e.print(gil.python()))
49+
.unwrap();
50+
51+
let mut sum = 0.0;
52+
for (x, y) in iter {
53+
sum += *x * *y;
54+
}
55+
56+
assert_eq!(sum, 145.0);
57+
Ok(())
58+
}

0 commit comments

Comments
 (0)