Skip to content

Commit 2132036

Browse files
committed
Adds some more comments and a new test.
1 parent c79f220 commit 2132036

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

src/npyiter.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,17 @@ impl<'py, T: TypeNum> NpyIterBuilder<'py, T> {
7070
}
7171
}
7272

73-
pub fn add(mut self, flag: NpyIterFlag) -> Self {
73+
pub fn set(mut self, flag: NpyIterFlag) -> Self {
74+
if flag == NpyIterFlag::ExternalLoop {
75+
// TODO: I don't want to make set fallible, but also we don't want to
76+
// support ExternalLoop yet (maybe ever?).
77+
panic!("rust-numpy does not currently support ExternalLoop access");
78+
}
7479
self.flags |= flag.to_c_enum();
7580
self
7681
}
7782

78-
pub fn remove(mut self, flag: NpyIterFlag) -> Self {
83+
pub fn unset(mut self, flag: NpyIterFlag) -> Self {
7984
self.flags &= !flag.to_c_enum();
8085
self
8186
}
@@ -143,6 +148,10 @@ impl<'py, T: 'py> std::iter::Iterator for NpyIterSingleArray<'py, T> {
143148
if self.empty {
144149
None
145150
} else {
151+
// Note: This pointer is correct and doesn't need to be updated,
152+
// note that we're derefencing a **char into a *char casting to a *T
153+
// and then transforming that into a reference, the value that dataptr
154+
// points to is being updated by iternext to point to the next value.
146155
let retval = Some(unsafe { &*(*self.dataptr as *mut T) });
147156
self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0;
148157
retval

tests/iter.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use numpy::{npyiter::NpyIterFlag, *};
2+
use pyo3::PyResult;
23

34
#[test]
45
fn get_iter() {
@@ -12,3 +13,21 @@ fn get_iter() {
1213
.unwrap();
1314
assert_eq!(*iter.next().unwrap(), 0.0);
1415
}
16+
17+
#[test]
18+
fn sum_iter() -> PyResult<()> {
19+
let gil = pyo3::Python::acquire_gil();
20+
let vec_data = vec![vec![0.0, 1.0], vec![2.0, 3.0], vec![4.0, 5.0]];
21+
22+
let arr = PyArray::from_vec2(gil.python(), &vec_data)?;
23+
let iter = npyiter::NpyIterBuilder::new(arr)
24+
.add(NpyIterFlag::ReadOnly)
25+
.build()
26+
.map_err(|e| e.print(gil.python()))
27+
.unwrap();
28+
29+
// The order of iteration is not specified, so we should restrict ourselves
30+
// to tests that don't verify a given order.
31+
assert_eq!(iter.sum::<f64>(), 15.0);
32+
Ok(())
33+
}

0 commit comments

Comments
 (0)