Skip to content

Commit f38bbd8

Browse files
committed
Starts work on a safe npyiter interface.
1 parent 50ee9ef commit f38bbd8

File tree

2 files changed

+183
-0
lines changed

2 files changed

+183
-0
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pub mod array;
4040
pub mod convert;
4141
mod error;
4242
pub mod npyffi;
43+
pub mod npyiter;
4344
mod readonly;
4445
mod slice_box;
4546
mod types;

src/npyiter.rs

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
use crate::array::PyArray;
2+
use crate::npyffi;
3+
use crate::npyffi::array::PY_ARRAY_API;
4+
use crate::npyffi::objects;
5+
use crate::npyffi::types::{npy_uint32, NPY_CASTING, NPY_ORDER};
6+
use pyo3::prelude::*;
7+
8+
use std::marker::PhantomData;
9+
use std::os::raw::*;
10+
use std::ptr;
11+
12+
pub enum NPyIterFlag {
13+
CIndex,
14+
FIndex,
15+
MultiIndex,
16+
ExternalLoop,
17+
CommonDtype,
18+
RefsOk,
19+
ZerosizeOk,
20+
ReduceOk,
21+
Ranged,
22+
Buffered,
23+
GrowInner,
24+
DelayBufAlloc,
25+
DontNegateStrides,
26+
CopyIfOverlap,
27+
}
28+
29+
/*
30+
31+
#define NPY_ITER_C_INDEX 0x00000001
32+
#define NPY_ITER_F_INDEX 0x00000002
33+
#define NPY_ITER_MULTI_INDEX 0x00000004
34+
#define NPY_ITER_EXTERNAL_LOOP 0x00000008
35+
#define NPY_ITER_COMMON_DTYPE 0x00000010
36+
#define NPY_ITER_REFS_OK 0x00000020
37+
#define NPY_ITER_ZEROSIZE_OK 0x00000040
38+
#define NPY_ITER_REDUCE_OK 0x00000080
39+
#define NPY_ITER_RANGED 0x00000100
40+
#define NPY_ITER_BUFFERED 0x00000200
41+
#define NPY_ITER_GROWINNER 0x00000400
42+
#define NPY_ITER_DELAY_BUFALLOC 0x00000800
43+
#define NPY_ITER_DONT_NEGATE_STRIDES 0x00001000
44+
#define NPY_ITER_COPY_IF_OVERLAP 0x00002000
45+
#define NPY_ITER_READWRITE 0x00010000
46+
#define NPY_ITER_READONLY 0x00020000
47+
#define NPY_ITER_WRITEONLY 0x00040000
48+
#define NPY_ITER_NBO 0x00080000
49+
#define NPY_ITER_ALIGNED 0x00100000
50+
#define NPY_ITER_CONTIG 0x00200000
51+
#define NPY_ITER_COPY 0x00400000
52+
#define NPY_ITER_UPDATEIFCOPY 0x00800000
53+
#define NPY_ITER_ALLOCATE 0x01000000
54+
#define NPY_ITER_NO_SUBTYPE 0x02000000
55+
#define NPY_ITER_VIRTUAL 0x04000000
56+
#define NPY_ITER_NO_BROADCAST 0x08000000
57+
#define NPY_ITER_WRITEMASKED 0x10000000
58+
#define NPY_ITER_ARRAYMASK 0x20000000
59+
#define NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE 0x40000000
60+
61+
#define NPY_ITER_GLOBAL_FLAGS 0x0000ffff
62+
#define NPY_ITER_PER_OP_FLAGS 0xffff0000
63+
64+
*/
65+
66+
impl NPyIterFlag {
67+
fn to_c_enum(&self) -> npy_uint32 {
68+
use NPyIterFlag::*;
69+
match self {
70+
CIndex => 0x00000001,
71+
FIndex => 0x00000002,
72+
MultiIndex => 0x00000004,
73+
ExternalLoop => 0x00000008,
74+
CommonDtype => 0x00000010,
75+
RefsOk => 0x00000020,
76+
ZerosizeOk => 0x00000040,
77+
ReduceOk => 0x00000080,
78+
Ranged => 0x00000100,
79+
Buffered => 0x00000200,
80+
GrowInner => 0x00000400,
81+
DelayBufAlloc => 0x00000800,
82+
DontNegateStrides => 0x00001000,
83+
CopyIfOverlap => 0x00002000,
84+
}
85+
}
86+
}
87+
88+
pub struct NpyIterBuilder<'py, T> {
89+
flags: npy_uint32,
90+
array: *mut npyffi::PyArrayObject,
91+
py: Python<'py>,
92+
return_type: PhantomData<T>,
93+
}
94+
95+
impl<'py, T> NpyIterBuilder<'py, T> {
96+
pub fn new<D>(array: PyArray<T, D>, py: Python<'py>) -> NpyIterBuilder<'py, T> {
97+
NpyIterBuilder {
98+
array: array.as_array_ptr(),
99+
py,
100+
flags: 0,
101+
return_type: PhantomData,
102+
}
103+
}
104+
105+
pub fn set_iter_flags(&mut self, flag: NPyIterFlag, value: bool) -> &mut Self {
106+
if value {
107+
self.flags |= flag.to_c_enum();
108+
} else {
109+
self.flags &= !flag.to_c_enum();
110+
}
111+
self
112+
}
113+
114+
pub fn finish(self) -> Option<NpyIterSingleArray<'py, T>> {
115+
let iter_ptr = unsafe {
116+
PY_ARRAY_API.NpyIter_New(
117+
self.array,
118+
self.flags,
119+
NPY_ORDER::NPY_ANYORDER,
120+
NPY_CASTING::NPY_SAFE_CASTING,
121+
ptr::null_mut(),
122+
)
123+
};
124+
125+
NpyIterSingleArray::new(iter_ptr, self.py)
126+
}
127+
}
128+
129+
pub struct NpyIterSingleArray<'py, T> {
130+
iterator: ptr::NonNull<objects::NpyIter>,
131+
iternext: unsafe extern "C" fn(*mut objects::NpyIter) -> c_int,
132+
empty: bool,
133+
dataptr: *mut *mut c_char,
134+
135+
return_type: PhantomData<T>,
136+
_py: Python<'py>,
137+
}
138+
139+
impl<'py, T> NpyIterSingleArray<'py, T> {
140+
fn new(iterator: *mut objects::NpyIter, py: Python<'py>) -> Option<NpyIterSingleArray<'py, T>> {
141+
let mut iterator = ptr::NonNull::new(iterator)?;
142+
143+
// TODO replace the null second arg with something correct.
144+
let iternext =
145+
unsafe { PY_ARRAY_API.NpyIter_GetIterNext(iterator.as_mut(), ptr::null_mut())? };
146+
let dataptr = unsafe { PY_ARRAY_API.NpyIter_GetDataPtrArray(iterator.as_mut()) };
147+
148+
if dataptr.is_null() {
149+
unsafe { PY_ARRAY_API.NpyIter_Deallocate(iterator.as_mut()) };
150+
}
151+
152+
Some(NpyIterSingleArray {
153+
iterator,
154+
iternext,
155+
empty: false, // TODO: Handle empty iterators
156+
dataptr,
157+
return_type: PhantomData,
158+
_py: py,
159+
})
160+
}
161+
}
162+
163+
impl<'py, T> Drop for NpyIterSingleArray<'py, T> {
164+
fn drop(&mut self) {
165+
let _success = unsafe { PY_ARRAY_API.NpyIter_Deallocate(self.iterator.as_mut()) };
166+
// TODO: Handle _success somehow?
167+
}
168+
}
169+
170+
impl<'py, T: 'py> std::iter::Iterator for NpyIterSingleArray<'py, T> {
171+
type Item = &'py T;
172+
173+
fn next(&mut self) -> Option<Self::Item> {
174+
if self.empty {
175+
None
176+
} else {
177+
let retval = Some(unsafe { &*(*self.dataptr as *mut T) });
178+
self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
179+
retval
180+
}
181+
}
182+
}

0 commit comments

Comments
 (0)