@@ -158,3 +158,254 @@ impl<'py, T: 'py> std::iter::Iterator for NpyIterSingleArray<'py, T> {
158
158
}
159
159
}
160
160
}
161
+
162
+ pub trait MultiIterMode { }
163
+
164
+ impl MultiIterMode for ( ) { }
165
+
166
+ pub struct RO < S > {
167
+ structure : PhantomData < S > ,
168
+ }
169
+
170
+ impl < S : MultiIterMode > MultiIterMode for RO < S > { }
171
+
172
+ pub struct RW < S > {
173
+ structure : PhantomData < S > ,
174
+ }
175
+
176
+ impl < S : MultiIterMode > MultiIterMode for RW < S > { }
177
+
178
+ pub trait MultiIterModeHasManyArrays : MultiIterMode { }
179
+ impl MultiIterModeHasManyArrays for RO < RO < ( ) > > { }
180
+ impl MultiIterModeHasManyArrays for RO < RW < ( ) > > { }
181
+ impl MultiIterModeHasManyArrays for RW < RO < ( ) > > { }
182
+ impl MultiIterModeHasManyArrays for RW < RW < ( ) > > { }
183
+
184
+ impl < S : MultiIterModeHasManyArrays > MultiIterModeHasManyArrays for RO < S > { }
185
+ impl < S : MultiIterModeHasManyArrays > MultiIterModeHasManyArrays for RW < S > { }
186
+
187
+ pub struct NpyMultiIterBuilder < ' py , T , S : MultiIterMode > {
188
+ flags : npy_uint32 ,
189
+ opflags : Vec < npy_uint32 > ,
190
+ arrays : Vec < & ' py PyArrayDyn < T > > ,
191
+ structure : PhantomData < S > ,
192
+ }
193
+
194
+ impl < ' py , T : TypeNum > NpyMultiIterBuilder < ' py , T , ( ) > {
195
+ pub fn new ( ) -> Self {
196
+ Self {
197
+ flags : 0 ,
198
+ opflags : Vec :: new ( ) ,
199
+ arrays : Vec :: new ( ) ,
200
+ structure : PhantomData ,
201
+ }
202
+ }
203
+
204
+ pub fn set ( mut self , flag : NpyIterFlag ) -> Self {
205
+ if flag == NpyIterFlag :: ExternalLoop {
206
+ // TODO: I don't want to make set fallible, but also we don't want to
207
+ // support ExternalLoop yet (maybe ever?).
208
+ panic ! ( "rust-numpy does not currently support ExternalLoop access" ) ;
209
+ }
210
+ self . flags |= flag. to_c_enum ( ) ;
211
+ self
212
+ }
213
+
214
+ pub fn unset ( mut self , flag : NpyIterFlag ) -> Self {
215
+ self . flags &= !flag. to_c_enum ( ) ;
216
+ self
217
+ }
218
+ }
219
+
220
+ impl < ' py , T : TypeNum , S : MultiIterMode > NpyMultiIterBuilder < ' py , T , S > {
221
+ pub fn add_readonly_array < D : ndarray:: Dimension > (
222
+ mut self ,
223
+ array : & ' py PyArray < T , D > ,
224
+ ) -> NpyMultiIterBuilder < ' py , T , RO < S > > {
225
+ self . arrays . push ( array. into_dyn ( ) ) ;
226
+ self . opflags . push ( NPY_ITER_READONLY ) ;
227
+
228
+ NpyMultiIterBuilder {
229
+ flags : self . flags ,
230
+ opflags : self . opflags ,
231
+ arrays : self . arrays ,
232
+ structure : PhantomData ,
233
+ }
234
+ }
235
+
236
+ pub fn add_readwrite_array < D : ndarray:: Dimension > (
237
+ mut self ,
238
+ array : & ' py PyArray < T , D > ,
239
+ ) -> NpyMultiIterBuilder < ' py , T , RW < S > > {
240
+ self . arrays . push ( array. into_dyn ( ) ) ;
241
+ self . opflags . push ( NPY_ITER_READWRITE ) ;
242
+
243
+ NpyMultiIterBuilder {
244
+ flags : self . flags ,
245
+ opflags : self . opflags ,
246
+ arrays : self . arrays ,
247
+ structure : PhantomData ,
248
+ }
249
+ }
250
+ }
251
+
252
+ impl < ' py , T : TypeNum , S : MultiIterModeHasManyArrays > NpyMultiIterBuilder < ' py , T , S > {
253
+ pub fn build ( mut self ) -> PyResult < NpyMultiIterArray < ' py , T , S > > {
254
+ assert ! ( self . arrays. len( ) == self . opflags. len( ) ) ;
255
+ assert ! ( self . arrays. len( ) <= i32 :: MAX as usize ) ;
256
+ assert ! ( 2 <= self . arrays. len( ) ) ;
257
+
258
+ let iter_ptr = unsafe {
259
+ PY_ARRAY_API . NpyIter_MultiNew (
260
+ self . arrays . len ( ) as i32 ,
261
+ self . arrays
262
+ . iter_mut ( )
263
+ . map ( |x| x. as_array_ptr ( ) )
264
+ . collect :: < Vec < _ > > ( )
265
+ . as_mut_ptr ( ) ,
266
+ self . flags ,
267
+ NPY_ORDER :: NPY_ANYORDER ,
268
+ NPY_CASTING :: NPY_SAFE_CASTING ,
269
+ self . opflags . as_mut_ptr ( ) ,
270
+ ptr:: null_mut ( ) ,
271
+ )
272
+ } ;
273
+ let py = self . arrays [ 0 ] . py ( ) ;
274
+ NpyMultiIterArray :: new ( iter_ptr, py) . ok_or_else ( || PyErr :: fetch ( py) )
275
+ }
276
+ }
277
+
278
+ pub struct NpyMultiIterArray < ' py , T , S : MultiIterModeHasManyArrays > {
279
+ iterator : ptr:: NonNull < objects:: NpyIter > ,
280
+ iternext : unsafe extern "C" fn ( * mut objects:: NpyIter ) -> c_int ,
281
+ empty : bool ,
282
+ dataptr : * mut * mut c_char ,
283
+
284
+ return_type : PhantomData < T > ,
285
+ structure : PhantomData < S > ,
286
+ _py : Python < ' py > ,
287
+ }
288
+
289
+ impl < ' py , T , S : MultiIterModeHasManyArrays > NpyMultiIterArray < ' py , T , S > {
290
+ fn new ( iterator : * mut objects:: NpyIter , py : Python < ' py > ) -> Option < Self > {
291
+ let mut iterator = ptr:: NonNull :: new ( iterator) ?;
292
+
293
+ // TODO replace the null second arg with something correct.
294
+ let iternext =
295
+ unsafe { PY_ARRAY_API . NpyIter_GetIterNext ( iterator. as_mut ( ) , ptr:: null_mut ( ) ) ? } ;
296
+ let dataptr = unsafe { PY_ARRAY_API . NpyIter_GetDataPtrArray ( iterator. as_mut ( ) ) } ;
297
+
298
+ if dataptr. is_null ( ) {
299
+ unsafe { PY_ARRAY_API . NpyIter_Deallocate ( iterator. as_mut ( ) ) } ;
300
+ }
301
+
302
+ Some ( Self {
303
+ iterator,
304
+ iternext,
305
+ empty : false , // TODO: Handle empty iterators
306
+ dataptr,
307
+ return_type : PhantomData ,
308
+ structure : PhantomData ,
309
+ _py : py,
310
+ } )
311
+ }
312
+ }
313
+
314
+ impl < ' py , T , S : MultiIterModeHasManyArrays > Drop for NpyMultiIterArray < ' py , T , S > {
315
+ fn drop ( & mut self ) {
316
+ let _success = unsafe { PY_ARRAY_API . NpyIter_Deallocate ( self . iterator . as_mut ( ) ) } ;
317
+ // TODO: Handle _success somehow?
318
+ }
319
+ }
320
+
321
+ impl < ' py , T : ' py > std:: iter:: Iterator for NpyMultiIterArray < ' py , T , RO < RO < ( ) > > > {
322
+ type Item = ( & ' py T , & ' py T ) ;
323
+
324
+ fn next ( & mut self ) -> Option < Self :: Item > {
325
+ if self . empty {
326
+ None
327
+ } else {
328
+ // Note: This pointer is correct and doesn't need to be updated,
329
+ // note that we're derefencing a **char into a *char casting to a *T
330
+ // and then transforming that into a reference, the value that dataptr
331
+ // points to is being updated by iternext to point to the next value.
332
+ let retval = Some ( unsafe {
333
+ (
334
+ & * ( * self . dataptr as * mut T ) ,
335
+ & * ( * self . dataptr . offset ( 1 ) as * mut T ) ,
336
+ )
337
+ } ) ;
338
+ self . empty = unsafe { ( self . iternext ) ( self . iterator . as_mut ( ) ) } == 0 ;
339
+ retval
340
+ }
341
+ }
342
+ }
343
+
344
+ impl < ' py , T : ' py > std:: iter:: Iterator for NpyMultiIterArray < ' py , T , RO < RW < ( ) > > > {
345
+ type Item = ( & ' py mut T , & ' py T ) ;
346
+
347
+ fn next ( & mut self ) -> Option < Self :: Item > {
348
+ if self . empty {
349
+ None
350
+ } else {
351
+ // Note: This pointer is correct and doesn't need to be updated,
352
+ // note that we're derefencing a **char into a *char casting to a *T
353
+ // and then transforming that into a reference, the value that dataptr
354
+ // points to is being updated by iternext to point to the next value.
355
+ let retval = Some ( unsafe {
356
+ (
357
+ & mut * ( * self . dataptr as * mut T ) ,
358
+ & * ( * self . dataptr . offset ( 1 ) as * mut T ) ,
359
+ )
360
+ } ) ;
361
+ self . empty = unsafe { ( self . iternext ) ( self . iterator . as_mut ( ) ) } == 0 ;
362
+ retval
363
+ }
364
+ }
365
+ }
366
+
367
+ impl < ' py , T : ' py > std:: iter:: Iterator for NpyMultiIterArray < ' py , T , RW < RO < ( ) > > > {
368
+ type Item = ( & ' py T , & ' py mut T ) ;
369
+
370
+ fn next ( & mut self ) -> Option < Self :: Item > {
371
+ if self . empty {
372
+ None
373
+ } else {
374
+ // Note: This pointer is correct and doesn't need to be updated,
375
+ // note that we're derefencing a **char into a *char casting to a *T
376
+ // and then transforming that into a reference, the value that dataptr
377
+ // points to is being updated by iternext to point to the next value.
378
+ let retval = Some ( unsafe {
379
+ (
380
+ & * ( * self . dataptr as * mut T ) ,
381
+ & mut * ( * self . dataptr . offset ( 1 ) as * mut T ) ,
382
+ )
383
+ } ) ;
384
+ self . empty = unsafe { ( self . iternext ) ( self . iterator . as_mut ( ) ) } == 0 ;
385
+ retval
386
+ }
387
+ }
388
+ }
389
+
390
+ impl < ' py , T : ' py > std:: iter:: Iterator for NpyMultiIterArray < ' py , T , RW < RW < ( ) > > > {
391
+ type Item = ( & ' py mut T , & ' py mut T ) ;
392
+
393
+ fn next ( & mut self ) -> Option < Self :: Item > {
394
+ if self . empty {
395
+ None
396
+ } else {
397
+ // Note: This pointer is correct and doesn't need to be updated,
398
+ // note that we're derefencing a **char into a *char casting to a *T
399
+ // and then transforming that into a reference, the value that dataptr
400
+ // points to is being updated by iternext to point to the next value.
401
+ let retval = Some ( unsafe {
402
+ (
403
+ & mut * ( * self . dataptr as * mut T ) ,
404
+ & mut * ( * self . dataptr . offset ( 1 ) as * mut T ) ,
405
+ )
406
+ } ) ;
407
+ self . empty = unsafe { ( self . iternext ) ( self . iterator . as_mut ( ) ) } == 0 ;
408
+ retval
409
+ }
410
+ }
411
+ }
0 commit comments