@@ -309,27 +309,42 @@ def concat(arrays, axis=0, dtype=None):
309
309
return result
310
310
311
311
312
- class LArrayIterator (object ):
313
- __slots__ = ('nextfunc' , 'axes' )
312
+ if PY2 :
313
+ class LArrayIterator (object ):
314
+ __slots__ = ('next' ,)
314
315
315
- def __init__ (self , array ):
316
- data_iter = iter (array .data )
317
- self .nextfunc = data_iter .next if PY2 else data_iter .__next__
318
- self .axes = array .axes [1 :]
316
+ def __init__ (self , array ):
317
+ data_iter = iter (array .data )
318
+ next_data_func = data_iter .next
319
+ res_axes = array .axes [1 :]
320
+ # this case should not happen (handled by the fastpath in LArray.__iter__)
321
+ assert len (res_axes ) > 0
319
322
320
- def __iter__ ( self ):
321
- return self
323
+ def next_func ( ):
324
+ return LArray ( next_data_func (), res_axes )
322
325
323
- def __next__ (self ):
324
- data = self .nextfunc ()
325
- axes = self .axes
326
- if len (axes ):
327
- return LArray (data , axes )
328
- else :
329
- return data
326
+ self .next = next_func
330
327
331
- # Python 2
332
- next = __next__
328
+ def __iter__ (self ):
329
+ return self
330
+ else :
331
+ class LArrayIterator (object ):
332
+ __slots__ = ('__next__' ,)
333
+
334
+ def __init__ (self , array ):
335
+ data_iter = iter (array .data )
336
+ next_data_func = data_iter .__next__
337
+ res_axes = array .axes [1 :]
338
+ # this case should not happen (handled by the fastpath in LArray.__iter__)
339
+ assert len (res_axes ) > 0
340
+
341
+ def next_func ():
342
+ return LArray (next_data_func (), res_axes )
343
+
344
+ self .__next__ = next_func
345
+
346
+ def __iter__ (self ):
347
+ return self
333
348
334
349
335
350
# TODO: rename to LArrayIndexIndexer or something like that
@@ -355,14 +370,41 @@ def _translate_key(self, key):
355
370
for axis_key , axis in zip (key , self .array .axes ))
356
371
357
372
def __getitem__ (self , key ):
358
- return self .array [self ._translate_key (key )]
373
+ ndim = self .array .ndim
374
+ full_scalar_key = (
375
+ (isinstance (key , (int , np .integer )) and ndim == 1 ) or
376
+ (isinstance (key , tuple ) and len (key ) == ndim and all (isinstance (k , (int , np .integer )) for k in key ))
377
+ )
378
+ # fast path when the result is a scalar
379
+ if full_scalar_key :
380
+ return self .array .data [key ]
381
+ else :
382
+ return self .array [self ._translate_key (key )]
359
383
360
384
def __setitem__ (self , key , value ):
361
- self .array [self ._translate_key (key )] = value
385
+ array = self .array
386
+ ndim = array .ndim
387
+ full_scalar_key = (
388
+ (isinstance (key , (int , np .integer )) and ndim == 1 ) or
389
+ (isinstance (key , tuple ) and len (key ) == ndim and all (isinstance (k , (int , np .integer )) for k in key ))
390
+ )
391
+ # fast path when setting a single cell
392
+ if full_scalar_key :
393
+ array .data [key ] = value
394
+ else :
395
+ array [self ._translate_key (key )] = value
362
396
363
397
def __len__ (self ):
364
398
return len (self .array )
365
399
400
+ def __iter__ (self ):
401
+ array = self .array
402
+ # fast path for 1D arrays (where we return scalars)
403
+ if array .ndim <= 1 :
404
+ return iter (array .data )
405
+ else :
406
+ return LArrayIterator (array )
407
+
366
408
367
409
class LArrayPointsIndexer (object ):
368
410
__slots__ = ('array' ,)
@@ -2696,6 +2738,7 @@ def _group_aggregate(self, op, items, keepaxes=False, out=None, **kwargs):
2696
2738
arr = np .asarray (arr )
2697
2739
op (arr , axis = axis_idx , out = out , ** kwargs )
2698
2740
del arr
2741
+
2699
2742
if killaxis :
2700
2743
assert group_idx [axis_idx ] == 0
2701
2744
res_data = res_data [idx ]
0 commit comments