Skip to content

Commit 0e6a618

Browse files
committed
Cleans up flag handling for NpyIter. Removes a bunch of unsafe enum flags.
1 parent ef40c03 commit 0e6a618

File tree

2 files changed

+43
-39
lines changed

2 files changed

+43
-39
lines changed

src/npyiter.rs

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::npyffi::{
44
types::{NPY_CASTING, NPY_ORDER},
55
*,
66
};
7-
use crate::types::TypeNum;
7+
use crate::types::Element;
88
use pyo3::{prelude::*, PyNativeType};
99

1010
use std::marker::PhantomData;
@@ -13,10 +13,11 @@ use std::ptr;
1313

1414
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
1515
pub enum NpyIterFlag {
16-
CIndex,
16+
/* CIndex,
1717
FIndex,
18-
MultiIndex,
19-
ExternalLoop,
18+
MultiIndex, */
19+
// ExternalLoop, // This flag greatly modifies the behaviour of accessing the data
20+
// so we don't support it.
2021
CommonDtype,
2122
RefsOk,
2223
ZerosizeOk,
@@ -27,19 +28,19 @@ pub enum NpyIterFlag {
2728
DelayBufAlloc,
2829
DontNegateStrides,
2930
CopyIfOverlap,
30-
ReadWrite,
31+
/* ReadWrite,
3132
ReadOnly,
32-
WriteOnly,
33+
WriteOnly, */
3334
}
3435

3536
impl NpyIterFlag {
3637
fn to_c_enum(&self) -> npy_uint32 {
3738
use NpyIterFlag::*;
3839
match self {
39-
CIndex => NPY_ITER_C_INDEX,
40+
/* CIndex => NPY_ITER_C_INDEX,
4041
FIndex => NPY_ITER_C_INDEX,
41-
MultiIndex => NPY_ITER_MULTI_INDEX,
42-
ExternalLoop => NPY_ITER_EXTERNAL_LOOP,
42+
MultiIndex => NPY_ITER_MULTI_INDEX, */
43+
/* ExternalLoop => NPY_ITER_EXTERNAL_LOOP, */
4344
CommonDtype => NPY_ITER_COMMON_DTYPE,
4445
RefsOk => NPY_ITER_REFS_OK,
4546
ZerosizeOk => NPY_ITER_ZEROSIZE_OK,
@@ -50,9 +51,9 @@ impl NpyIterFlag {
5051
DelayBufAlloc => NPY_ITER_DELAY_BUFALLOC,
5152
DontNegateStrides => NPY_ITER_DONT_NEGATE_STRIDES,
5253
CopyIfOverlap => NPY_ITER_COPY_IF_OVERLAP,
53-
ReadWrite => NPY_ITER_READWRITE,
54+
/* ReadWrite => NPY_ITER_READWRITE,
5455
ReadOnly => NPY_ITER_READONLY,
55-
WriteOnly => NPY_ITER_WRITEONLY,
56+
WriteOnly => NPY_ITER_WRITEONLY, */
5657
}
5758
}
5859
}
@@ -62,20 +63,22 @@ pub struct NpyIterBuilder<'py, T> {
6263
array: &'py PyArrayDyn<T>,
6364
}
6465

65-
impl<'py, T: TypeNum> NpyIterBuilder<'py, T> {
66-
pub fn new<D: ndarray::Dimension>(array: &'py PyArray<T, D>) -> NpyIterBuilder<'py, T> {
66+
impl<'py, T: Element> NpyIterBuilder<'py, T> {
67+
pub fn readwrite<D: ndarray::Dimension>(array: &'py PyArray<T, D>) -> NpyIterBuilder<'py, T> {
6768
NpyIterBuilder {
68-
flags: 0,
69-
array: array.into_dyn(),
69+
flags: NPY_ITER_READWRITE,
70+
array: array.to_dyn(),
7071
}
7172
}
7273

73-
pub fn set(mut self, flag: NpyIterFlag) -> Self {
74-
if flag == NpyIterFlag::ExternalLoop {
75-
// TODO: I don't want to make set fallible, but also we don't want to
76-
// support ExternalLoop yet (maybe ever?).
77-
panic!("rust-numpy does not currently support ExternalLoop access");
74+
pub fn readonly<D: ndarray::Dimension>(array: &'py PyArray<T, D>) -> NpyIterBuilder<'py, T> {
75+
NpyIterBuilder {
76+
flags: NPY_ITER_READONLY,
77+
array: array.to_dyn(),
7878
}
79+
}
80+
81+
pub fn set(mut self, flag: NpyIterFlag) -> Self {
7982
self.flags |= flag.to_c_enum();
8083
self
8184
}
@@ -191,7 +194,7 @@ pub struct NpyMultiIterBuilder<'py, T, S: MultiIterMode> {
191194
structure: PhantomData<S>,
192195
}
193196

194-
impl<'py, T: TypeNum> NpyMultiIterBuilder<'py, T, ()> {
197+
impl<'py, T: Element> NpyMultiIterBuilder<'py, T, ()> {
195198
pub fn new() -> Self {
196199
Self {
197200
flags: 0,
@@ -202,11 +205,6 @@ impl<'py, T: TypeNum> NpyMultiIterBuilder<'py, T, ()> {
202205
}
203206

204207
pub fn set(mut self, flag: NpyIterFlag) -> Self {
205-
if flag == NpyIterFlag::ExternalLoop {
206-
// TODO: I don't want to make set fallible, but also we don't want to
207-
// support ExternalLoop yet (maybe ever?).
208-
panic!("rust-numpy does not currently support ExternalLoop access");
209-
}
210208
self.flags |= flag.to_c_enum();
211209
self
212210
}
@@ -217,12 +215,12 @@ impl<'py, T: TypeNum> NpyMultiIterBuilder<'py, T, ()> {
217215
}
218216
}
219217

220-
impl<'py, T: TypeNum, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> {
218+
impl<'py, T: Element, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> {
221219
pub fn add_readonly_array<D: ndarray::Dimension>(
222220
mut self,
223221
array: &'py PyArray<T, D>,
224222
) -> NpyMultiIterBuilder<'py, T, RO<S>> {
225-
self.arrays.push(array.into_dyn());
223+
self.arrays.push(array.to_dyn());
226224
self.opflags.push(NPY_ITER_READONLY);
227225

228226
NpyMultiIterBuilder {
@@ -237,7 +235,7 @@ impl<'py, T: TypeNum, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> {
237235
mut self,
238236
array: &'py PyArray<T, D>,
239237
) -> NpyMultiIterBuilder<'py, T, RW<S>> {
240-
self.arrays.push(array.into_dyn());
238+
self.arrays.push(array.to_dyn());
241239
self.opflags.push(NPY_ITER_READWRITE);
242240

243241
NpyMultiIterBuilder {
@@ -249,7 +247,7 @@ impl<'py, T: TypeNum, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> {
249247
}
250248
}
251249

252-
impl<'py, T: TypeNum, S: MultiIterModeHasManyArrays> NpyMultiIterBuilder<'py, T, S> {
250+
impl<'py, T: Element, S: MultiIterModeHasManyArrays> NpyMultiIterBuilder<'py, T, S> {
253251
pub fn build(mut self) -> PyResult<NpyMultiIterArray<'py, T, S>> {
254252
assert!(self.arrays.len() == self.opflags.len());
255253
assert!(self.arrays.len() <= i32::MAX as usize);
@@ -279,6 +277,7 @@ pub struct NpyMultiIterArray<'py, T, S: MultiIterModeHasManyArrays> {
279277
iterator: ptr::NonNull<objects::NpyIter>,
280278
iternext: unsafe extern "C" fn(*mut objects::NpyIter) -> c_int,
281279
empty: bool,
280+
iter_size: npy_intp,
282281
dataptr: *mut *mut c_char,
283282

284283
return_type: PhantomData<T>,
@@ -298,11 +297,14 @@ impl<'py, T, S: MultiIterModeHasManyArrays> NpyMultiIterArray<'py, T, S> {
298297
if dataptr.is_null() {
299298
unsafe { PY_ARRAY_API.NpyIter_Deallocate(iterator.as_mut()) };
300299
}
300+
301+
let iter_size = unsafe { PY_ARRAY_API.NpyIter_GetIterSize(iterator.as_mut()) };
301302

302303
Some(Self {
303304
iterator,
304305
iternext,
305-
empty: false, // TODO: Handle empty iterators
306+
iter_size,
307+
empty: iter_size != 0, // TODO: Handle empty iterators
306308
dataptr,
307309
return_type: PhantomData,
308310
structure: PhantomData,
@@ -339,6 +341,10 @@ impl<'py, T: 'py> std::iter::Iterator for NpyMultiIterArray<'py, T, $arg> {
339341
retval
340342
}
341343
}
344+
345+
fn size_hint(&self) -> (usize, Option<usize>) {
346+
(self.iter_size as usize, Some(self.iter_size as usize))
347+
}
342348
}
343349
}
344350
}

tests/iter.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ fn get_iter() {
66
let gil = pyo3::Python::acquire_gil();
77
let dim = (3, 5);
88
let arr = PyArray::<f64, _>::zeros(gil.python(), dim, false);
9-
let mut iter = npyiter::NpyIterBuilder::new(arr)
10-
.set(NpyIterFlag::ReadOnly)
9+
let mut iter = npyiter::NpyIterBuilder::readonly(arr)
1110
.build()
1211
.map_err(|e| e.print(gil.python()))
1312
.unwrap();
@@ -20,8 +19,7 @@ fn sum_iter() -> PyResult<()> {
2019
let vec_data = vec![vec![0.0, 1.0], vec![2.0, 3.0], vec![4.0, 5.0]];
2120

2221
let arr = PyArray::from_vec2(gil.python(), &vec_data)?;
23-
let iter = npyiter::NpyIterBuilder::new(arr)
24-
.set(NpyIterFlag::ReadOnly)
22+
let iter = npyiter::NpyIterBuilder::readonly(arr)
2523
.build()
2624
.map_err(|e| e.print(gil.python()))
2725
.unwrap();
@@ -38,14 +36,14 @@ fn multi_iter() -> PyResult<()> {
3836
let vec_data1 = vec![vec![0.0, 1.0], vec![2.0, 3.0], vec![4.0, 5.0]];
3937
let vec_data2 = vec![vec![6.0, 7.0], vec![8.0, 9.0], vec![10.0, 11.0]];
4038

41-
let arr1 = PyArray::from_vec2(gil.python(), &vec_data1)?;
42-
let arr2 = PyArray::from_vec2(gil.python(), &vec_data2)?;
43-
39+
let py = gil.python();
40+
let arr1 = PyArray::from_vec2(py, &vec_data1)?;
41+
let arr2 = PyArray::from_vec2(py, &vec_data2)?;
4442
let iter = npyiter::NpyMultiIterBuilder::new()
4543
.add_readonly_array(arr1)
4644
.add_readonly_array(arr2)
4745
.build()
48-
.map_err(|e| e.print(gil.python()))
46+
.map_err(|e| e.print(py))
4947
.unwrap();
5048

5149
let mut sum = 0.0;

0 commit comments

Comments
 (0)