Skip to content

Commit a70446c

Browse files
committed
Resyncs the NpySingleIter and NpyMultiIter implmentations.
1 parent 8de9fd9 commit a70446c

File tree

1 file changed

+29
-13
lines changed

1 file changed

+29
-13
lines changed

src/npyiter.rs

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -109,21 +109,21 @@ pub struct NpySingleIter<'py, T> {
109109
iterator: ptr::NonNull<objects::NpyIter>,
110110
iternext: unsafe extern "C" fn(*mut objects::NpyIter) -> c_int,
111111
empty: bool,
112+
iter_size: npy_intp,
112113
dataptr: *mut *mut c_char,
113114
return_type: PhantomData<T>,
114115
_py: Python<'py>,
115116
}
116117

117118
impl<'py, T> NpySingleIter<'py, T> {
118-
fn new(iterator: *mut objects::NpyIter, py: Python<'py>) -> PyResult<NpySingleIter<'py, T>> {
119+
fn new(iterator: *mut objects::NpyIter, py: Python<'py>) -> PyResult<Self> {
119120
let mut iterator = match ptr::NonNull::new(iterator) {
120121
Some(iter) => iter,
121122
None => {
122123
return Err(NpyIterInstantiationError.into());
123124
}
124125
};
125126

126-
// TODO replace the null second arg with something correct.
127127
let iternext = match unsafe { PY_ARRAY_API.NpyIter_GetIterNext(iterator.as_mut(), ptr::null_mut()) } {
128128
Some(ptr) => ptr,
129129
None => {
@@ -137,10 +137,13 @@ impl<'py, T> NpySingleIter<'py, T> {
137137
return Err(NpyIterInstantiationError.into());
138138
}
139139

140-
Ok(NpySingleIter {
140+
let iter_size = unsafe { PY_ARRAY_API.NpyIter_GetIterSize(iterator.as_mut()) };
141+
142+
Ok(Self {
141143
iterator,
142144
iternext,
143-
empty: false, // TODO: Handle empty iterators
145+
iter_size,
146+
empty: iter_size == 0,
144147
dataptr,
145148
return_type: PhantomData,
146149
_py: py,
@@ -171,6 +174,10 @@ impl<'py, T: 'py> std::iter::Iterator for NpySingleIter<'py, T> {
171174
retval
172175
}
173176
}
177+
178+
fn size_hint(&self) -> (usize, Option<usize>) {
179+
(self.iter_size as usize, Some(self.iter_size as usize))
180+
}
174181
}
175182

176183
mod private {
@@ -189,7 +196,7 @@ macro_rules! private_impl {
189196
};
190197
}
191198

192-
/// A combinator type that represents an terator mode (e.g., ReadOnly + ReadWrite + ReadOnly).
199+
/// A combinator type that represents an terator mode (e.g., ReadOnly + ReadWrite).
193200
pub trait MultiIterMode {
194201
private_decl!();
195202
type Pre: MultiIterMode;
@@ -316,7 +323,7 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T
316323
)
317324
};
318325
let py = self.arrays[0].py();
319-
NpyMultiIter::new(iter_ptr, py).ok_or_else(|| PyErr::fetch(py))
326+
NpyMultiIter::new(iter_ptr, py)
320327
}
321328
}
322329

@@ -332,25 +339,34 @@ pub struct NpyMultiIter<'py, T, S: MultiIterModeWithManyArrays> {
332339
}
333340

334341
impl<'py, T, S: MultiIterModeWithManyArrays> NpyMultiIter<'py, T, S> {
335-
fn new(iterator: *mut objects::NpyIter, py: Python<'py>) -> Option<Self> {
336-
let mut iterator = ptr::NonNull::new(iterator)?;
342+
fn new(iterator: *mut objects::NpyIter, py: Python<'py>) -> PyResult<Self> {
343+
let mut iterator = match ptr::NonNull::new(iterator) {
344+
Some(ptr) => ptr,
345+
None => {
346+
return Err(NpyIterInstantiationError.into());
347+
}
348+
};
337349

338-
// TODO replace the null second arg with something correct.
339-
let iternext =
340-
unsafe { PY_ARRAY_API.NpyIter_GetIterNext(iterator.as_mut(), ptr::null_mut())? };
350+
let iternext = match unsafe { PY_ARRAY_API.NpyIter_GetIterNext(iterator.as_mut(), ptr::null_mut()) } {
351+
Some(ptr) => ptr,
352+
None => {
353+
return Err(PyErr::fetch(py));
354+
}
355+
};
341356
let dataptr = unsafe { PY_ARRAY_API.NpyIter_GetDataPtrArray(iterator.as_mut()) };
342357

343358
if dataptr.is_null() {
344359
unsafe { PY_ARRAY_API.NpyIter_Deallocate(iterator.as_mut()) };
360+
return Err(NpyIterInstantiationError.into());
345361
}
346362

347363
let iter_size = unsafe { PY_ARRAY_API.NpyIter_GetIterSize(iterator.as_mut()) };
348364

349-
Some(Self {
365+
Ok(Self {
350366
iterator,
351367
iternext,
352368
iter_size,
353-
empty: iter_size == 0, // TODO: Handle empty iterators
369+
empty: iter_size == 0,
354370
dataptr,
355371
marker: PhantomData,
356372
_py: py,

0 commit comments

Comments
 (0)