@@ -4,7 +4,7 @@ use crate::npyffi::{
4
4
types:: { NPY_CASTING , NPY_ORDER } ,
5
5
* ,
6
6
} ;
7
- use crate :: types:: TypeNum ;
7
+ use crate :: types:: Element ;
8
8
use pyo3:: { prelude:: * , PyNativeType } ;
9
9
10
10
use std:: marker:: PhantomData ;
@@ -13,10 +13,11 @@ use std::ptr;
13
13
14
14
#[ derive( Clone , Copy , Debug , Eq , PartialEq ) ]
15
15
pub enum NpyIterFlag {
16
- CIndex ,
16
+ /* CIndex,
17
17
FIndex,
18
- MultiIndex ,
19
- ExternalLoop ,
18
+ MultiIndex, */
19
+ // ExternalLoop, // This flag greatly modifies the behaviour of accessing the data
20
+ // so we don't support it.
20
21
CommonDtype ,
21
22
RefsOk ,
22
23
ZerosizeOk ,
@@ -27,19 +28,19 @@ pub enum NpyIterFlag {
27
28
DelayBufAlloc ,
28
29
DontNegateStrides ,
29
30
CopyIfOverlap ,
30
- ReadWrite ,
31
+ /* ReadWrite,
31
32
ReadOnly,
32
- WriteOnly ,
33
+ WriteOnly, */
33
34
}
34
35
35
36
impl NpyIterFlag {
36
37
fn to_c_enum ( & self ) -> npy_uint32 {
37
38
use NpyIterFlag :: * ;
38
39
match self {
39
- CIndex => NPY_ITER_C_INDEX ,
40
+ /* CIndex => NPY_ITER_C_INDEX,
40
41
FIndex => NPY_ITER_C_INDEX,
41
- MultiIndex => NPY_ITER_MULTI_INDEX ,
42
- ExternalLoop => NPY_ITER_EXTERNAL_LOOP ,
42
+ MultiIndex => NPY_ITER_MULTI_INDEX, */
43
+ /* ExternalLoop => NPY_ITER_EXTERNAL_LOOP, */
43
44
CommonDtype => NPY_ITER_COMMON_DTYPE ,
44
45
RefsOk => NPY_ITER_REFS_OK ,
45
46
ZerosizeOk => NPY_ITER_ZEROSIZE_OK ,
@@ -50,9 +51,9 @@ impl NpyIterFlag {
50
51
DelayBufAlloc => NPY_ITER_DELAY_BUFALLOC ,
51
52
DontNegateStrides => NPY_ITER_DONT_NEGATE_STRIDES ,
52
53
CopyIfOverlap => NPY_ITER_COPY_IF_OVERLAP ,
53
- ReadWrite => NPY_ITER_READWRITE ,
54
+ /* ReadWrite => NPY_ITER_READWRITE,
54
55
ReadOnly => NPY_ITER_READONLY,
55
- WriteOnly => NPY_ITER_WRITEONLY ,
56
+ WriteOnly => NPY_ITER_WRITEONLY, */
56
57
}
57
58
}
58
59
}
@@ -62,20 +63,22 @@ pub struct NpyIterBuilder<'py, T> {
62
63
array : & ' py PyArrayDyn < T > ,
63
64
}
64
65
65
- impl < ' py , T : TypeNum > NpyIterBuilder < ' py , T > {
66
- pub fn new < D : ndarray:: Dimension > ( array : & ' py PyArray < T , D > ) -> NpyIterBuilder < ' py , T > {
66
+ impl < ' py , T : Element > NpyIterBuilder < ' py , T > {
67
+ pub fn readwrite < D : ndarray:: Dimension > ( array : & ' py PyArray < T , D > ) -> NpyIterBuilder < ' py , T > {
67
68
NpyIterBuilder {
68
- flags : 0 ,
69
- array : array. into_dyn ( ) ,
69
+ flags : NPY_ITER_READWRITE ,
70
+ array : array. to_dyn ( ) ,
70
71
}
71
72
}
72
73
73
- pub fn set ( mut self , flag : NpyIterFlag ) -> Self {
74
- if flag == NpyIterFlag :: ExternalLoop {
75
- // TODO: I don't want to make set fallible, but also we don't want to
76
- // support ExternalLoop yet (maybe ever?).
77
- panic ! ( "rust-numpy does not currently support ExternalLoop access" ) ;
74
+ pub fn readonly < D : ndarray:: Dimension > ( array : & ' py PyArray < T , D > ) -> NpyIterBuilder < ' py , T > {
75
+ NpyIterBuilder {
76
+ flags : NPY_ITER_READONLY ,
77
+ array : array. to_dyn ( ) ,
78
78
}
79
+ }
80
+
81
+ pub fn set ( mut self , flag : NpyIterFlag ) -> Self {
79
82
self . flags |= flag. to_c_enum ( ) ;
80
83
self
81
84
}
@@ -191,7 +194,7 @@ pub struct NpyMultiIterBuilder<'py, T, S: MultiIterMode> {
191
194
structure : PhantomData < S > ,
192
195
}
193
196
194
- impl < ' py , T : TypeNum > NpyMultiIterBuilder < ' py , T , ( ) > {
197
+ impl < ' py , T : Element > NpyMultiIterBuilder < ' py , T , ( ) > {
195
198
pub fn new ( ) -> Self {
196
199
Self {
197
200
flags : 0 ,
@@ -202,11 +205,6 @@ impl<'py, T: TypeNum> NpyMultiIterBuilder<'py, T, ()> {
202
205
}
203
206
204
207
pub fn set ( mut self , flag : NpyIterFlag ) -> Self {
205
- if flag == NpyIterFlag :: ExternalLoop {
206
- // TODO: I don't want to make set fallible, but also we don't want to
207
- // support ExternalLoop yet (maybe ever?).
208
- panic ! ( "rust-numpy does not currently support ExternalLoop access" ) ;
209
- }
210
208
self . flags |= flag. to_c_enum ( ) ;
211
209
self
212
210
}
@@ -217,12 +215,12 @@ impl<'py, T: TypeNum> NpyMultiIterBuilder<'py, T, ()> {
217
215
}
218
216
}
219
217
220
- impl < ' py , T : TypeNum , S : MultiIterMode > NpyMultiIterBuilder < ' py , T , S > {
218
+ impl < ' py , T : Element , S : MultiIterMode > NpyMultiIterBuilder < ' py , T , S > {
221
219
pub fn add_readonly_array < D : ndarray:: Dimension > (
222
220
mut self ,
223
221
array : & ' py PyArray < T , D > ,
224
222
) -> NpyMultiIterBuilder < ' py , T , RO < S > > {
225
- self . arrays . push ( array. into_dyn ( ) ) ;
223
+ self . arrays . push ( array. to_dyn ( ) ) ;
226
224
self . opflags . push ( NPY_ITER_READONLY ) ;
227
225
228
226
NpyMultiIterBuilder {
@@ -237,7 +235,7 @@ impl<'py, T: TypeNum, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> {
237
235
mut self ,
238
236
array : & ' py PyArray < T , D > ,
239
237
) -> NpyMultiIterBuilder < ' py , T , RW < S > > {
240
- self . arrays . push ( array. into_dyn ( ) ) ;
238
+ self . arrays . push ( array. to_dyn ( ) ) ;
241
239
self . opflags . push ( NPY_ITER_READWRITE ) ;
242
240
243
241
NpyMultiIterBuilder {
@@ -249,7 +247,7 @@ impl<'py, T: TypeNum, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> {
249
247
}
250
248
}
251
249
252
- impl < ' py , T : TypeNum , S : MultiIterModeHasManyArrays > NpyMultiIterBuilder < ' py , T , S > {
250
+ impl < ' py , T : Element , S : MultiIterModeHasManyArrays > NpyMultiIterBuilder < ' py , T , S > {
253
251
pub fn build ( mut self ) -> PyResult < NpyMultiIterArray < ' py , T , S > > {
254
252
assert ! ( self . arrays. len( ) == self . opflags. len( ) ) ;
255
253
assert ! ( self . arrays. len( ) <= i32 :: MAX as usize ) ;
@@ -279,6 +277,7 @@ pub struct NpyMultiIterArray<'py, T, S: MultiIterModeHasManyArrays> {
279
277
iterator : ptr:: NonNull < objects:: NpyIter > ,
280
278
iternext : unsafe extern "C" fn ( * mut objects:: NpyIter ) -> c_int ,
281
279
empty : bool ,
280
+ iter_size : npy_intp ,
282
281
dataptr : * mut * mut c_char ,
283
282
284
283
return_type : PhantomData < T > ,
@@ -298,11 +297,14 @@ impl<'py, T, S: MultiIterModeHasManyArrays> NpyMultiIterArray<'py, T, S> {
298
297
if dataptr. is_null ( ) {
299
298
unsafe { PY_ARRAY_API . NpyIter_Deallocate ( iterator. as_mut ( ) ) } ;
300
299
}
300
+
301
+ let iter_size = unsafe { PY_ARRAY_API . NpyIter_GetIterSize ( iterator. as_mut ( ) ) } ;
301
302
302
303
Some ( Self {
303
304
iterator,
304
305
iternext,
305
- empty : false , // TODO: Handle empty iterators
306
+ iter_size,
307
+ empty : iter_size != 0 , // TODO: Handle empty iterators
306
308
dataptr,
307
309
return_type : PhantomData ,
308
310
structure : PhantomData ,
@@ -339,6 +341,10 @@ impl<'py, T: 'py> std::iter::Iterator for NpyMultiIterArray<'py, T, $arg> {
339
341
retval
340
342
}
341
343
}
344
+
345
+ fn size_hint( & self ) -> ( usize , Option <usize >) {
346
+ ( self . iter_size as usize , Some ( self . iter_size as usize ) )
347
+ }
342
348
}
343
349
}
344
350
}
0 commit comments