@@ -109,21 +109,21 @@ pub struct NpySingleIter<'py, T> {
109
109
iterator : ptr:: NonNull < objects:: NpyIter > ,
110
110
iternext : unsafe extern "C" fn ( * mut objects:: NpyIter ) -> c_int ,
111
111
empty : bool ,
112
+ iter_size : npy_intp ,
112
113
dataptr : * mut * mut c_char ,
113
114
return_type : PhantomData < T > ,
114
115
_py : Python < ' py > ,
115
116
}
116
117
117
118
impl < ' py , T > NpySingleIter < ' py , T > {
118
- fn new ( iterator : * mut objects:: NpyIter , py : Python < ' py > ) -> PyResult < NpySingleIter < ' py , T > > {
119
+ fn new ( iterator : * mut objects:: NpyIter , py : Python < ' py > ) -> PyResult < Self > {
119
120
let mut iterator = match ptr:: NonNull :: new ( iterator) {
120
121
Some ( iter) => iter,
121
122
None => {
122
123
return Err ( NpyIterInstantiationError . into ( ) ) ;
123
124
}
124
125
} ;
125
126
126
- // TODO replace the null second arg with something correct.
127
127
let iternext = match unsafe { PY_ARRAY_API . NpyIter_GetIterNext ( iterator. as_mut ( ) , ptr:: null_mut ( ) ) } {
128
128
Some ( ptr) => ptr,
129
129
None => {
@@ -137,10 +137,13 @@ impl<'py, T> NpySingleIter<'py, T> {
137
137
return Err ( NpyIterInstantiationError . into ( ) ) ;
138
138
}
139
139
140
- Ok ( NpySingleIter {
140
+ let iter_size = unsafe { PY_ARRAY_API . NpyIter_GetIterSize ( iterator. as_mut ( ) ) } ;
141
+
142
+ Ok ( Self {
141
143
iterator,
142
144
iternext,
143
- empty : false , // TODO: Handle empty iterators
145
+ iter_size,
146
+ empty : iter_size == 0 ,
144
147
dataptr,
145
148
return_type : PhantomData ,
146
149
_py : py,
@@ -171,6 +174,10 @@ impl<'py, T: 'py> std::iter::Iterator for NpySingleIter<'py, T> {
171
174
retval
172
175
}
173
176
}
177
+
178
+ fn size_hint ( & self ) -> ( usize , Option < usize > ) {
179
+ ( self . iter_size as usize , Some ( self . iter_size as usize ) )
180
+ }
174
181
}
175
182
176
183
mod private {
@@ -189,7 +196,7 @@ macro_rules! private_impl {
189
196
} ;
190
197
}
191
198
192
- /// A combinator type that represents an terator mode (e.g., ReadOnly + ReadWrite + ReadOnly ).
199
+ /// A combinator type that represents an terator mode (e.g., ReadOnly + ReadWrite).
193
200
pub trait MultiIterMode {
194
201
private_decl ! ( ) ;
195
202
type Pre : MultiIterMode ;
@@ -316,7 +323,7 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T
316
323
)
317
324
} ;
318
325
let py = self . arrays [ 0 ] . py ( ) ;
319
- NpyMultiIter :: new ( iter_ptr, py) . ok_or_else ( || PyErr :: fetch ( py ) )
326
+ NpyMultiIter :: new ( iter_ptr, py)
320
327
}
321
328
}
322
329
@@ -332,25 +339,34 @@ pub struct NpyMultiIter<'py, T, S: MultiIterModeWithManyArrays> {
332
339
}
333
340
334
341
impl < ' py , T , S : MultiIterModeWithManyArrays > NpyMultiIter < ' py , T , S > {
335
- fn new ( iterator : * mut objects:: NpyIter , py : Python < ' py > ) -> Option < Self > {
336
- let mut iterator = ptr:: NonNull :: new ( iterator) ?;
342
+ fn new ( iterator : * mut objects:: NpyIter , py : Python < ' py > ) -> PyResult < Self > {
343
+ let mut iterator = match ptr:: NonNull :: new ( iterator) {
344
+ Some ( ptr) => ptr,
345
+ None => {
346
+ return Err ( NpyIterInstantiationError . into ( ) ) ;
347
+ }
348
+ } ;
337
349
338
- // TODO replace the null second arg with something correct.
339
- let iternext =
340
- unsafe { PY_ARRAY_API . NpyIter_GetIterNext ( iterator. as_mut ( ) , ptr:: null_mut ( ) ) ? } ;
350
+ let iternext = match unsafe { PY_ARRAY_API . NpyIter_GetIterNext ( iterator. as_mut ( ) , ptr:: null_mut ( ) ) } {
351
+ Some ( ptr) => ptr,
352
+ None => {
353
+ return Err ( PyErr :: fetch ( py) ) ;
354
+ }
355
+ } ;
341
356
let dataptr = unsafe { PY_ARRAY_API . NpyIter_GetDataPtrArray ( iterator. as_mut ( ) ) } ;
342
357
343
358
if dataptr. is_null ( ) {
344
359
unsafe { PY_ARRAY_API . NpyIter_Deallocate ( iterator. as_mut ( ) ) } ;
360
+ return Err ( NpyIterInstantiationError . into ( ) ) ;
345
361
}
346
362
347
363
let iter_size = unsafe { PY_ARRAY_API . NpyIter_GetIterSize ( iterator. as_mut ( ) ) } ;
348
364
349
- Some ( Self {
365
+ Ok ( Self {
350
366
iterator,
351
367
iternext,
352
368
iter_size,
353
- empty : iter_size == 0 , // TODO: Handle empty iterators
369
+ empty : iter_size == 0 ,
354
370
dataptr,
355
371
marker : PhantomData ,
356
372
_py : py,
0 commit comments