Skip to content

Commit 925e786

Browse files
committed
a few performance improvements
* a few percent improvements for iter(LArray) * a very large improvement for iter(LArray.i) which has an indirect impact on the performance of LArray.values and related methods/functions (LArray.items and zip_array*) * a large improvement for "LArray.i[key]" and "LArray.i[key] = value" when the result or target is a scalar
1 parent d6021a6 commit 925e786

File tree

2 files changed

+63
-20
lines changed

2 files changed

+63
-20
lines changed

doc/source/changes/version_0_31.rst.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ New features
2727
Miscellaneous improvements
2828
^^^^^^^^^^^^^^^^^^^^^^^^^^
2929

30-
* improved something.
30+
* improved the performance of a few LArray methods.
3131

3232

3333
Fixes

larray/core/array.py

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -309,27 +309,42 @@ def concat(arrays, axis=0, dtype=None):
309309
return result
310310

311311

312-
class LArrayIterator(object):
313-
__slots__ = ('nextfunc', 'axes')
312+
if PY2:
313+
class LArrayIterator(object):
314+
__slots__ = ('next',)
314315

315-
def __init__(self, array):
316-
data_iter = iter(array.data)
317-
self.nextfunc = data_iter.next if PY2 else data_iter.__next__
318-
self.axes = array.axes[1:]
316+
def __init__(self, array):
317+
data_iter = iter(array.data)
318+
next_data_func = data_iter.next
319+
res_axes = array.axes[1:]
320+
# this case should not happen (handled by the fastpath in LArray.__iter__)
321+
assert len(res_axes) > 0
319322

320-
def __iter__(self):
321-
return self
323+
def next_func():
324+
return LArray(next_data_func(), res_axes)
322325

323-
def __next__(self):
324-
data = self.nextfunc()
325-
axes = self.axes
326-
if len(axes):
327-
return LArray(data, axes)
328-
else:
329-
return data
326+
self.next = next_func
330327

331-
# Python 2
332-
next = __next__
328+
def __iter__(self):
329+
return self
330+
else:
331+
class LArrayIterator(object):
332+
__slots__ = ('__next__',)
333+
334+
def __init__(self, array):
335+
data_iter = iter(array.data)
336+
next_data_func = data_iter.__next__
337+
res_axes = array.axes[1:]
338+
# this case should not happen (handled by the fastpath in LArray.__iter__)
339+
assert len(res_axes) > 0
340+
341+
def next_func():
342+
return LArray(next_data_func(), res_axes)
343+
344+
self.__next__ = next_func
345+
346+
def __iter__(self):
347+
return self
333348

334349

335350
# TODO: rename to LArrayIndexIndexer or something like that
@@ -355,14 +370,41 @@ def _translate_key(self, key):
355370
for axis_key, axis in zip(key, self.array.axes))
356371

357372
def __getitem__(self, key):
358-
return self.array[self._translate_key(key)]
373+
ndim = self.array.ndim
374+
full_scalar_key = (
375+
(isinstance(key, (int, np.integer)) and ndim == 1) or
376+
(isinstance(key, tuple) and len(key) == ndim and all(isinstance(k, (int, np.integer)) for k in key))
377+
)
378+
# fast path when the result is a scalar
379+
if full_scalar_key:
380+
return self.array.data[key]
381+
else:
382+
return self.array[self._translate_key(key)]
359383

360384
def __setitem__(self, key, value):
361-
self.array[self._translate_key(key)] = value
385+
array = self.array
386+
ndim = array.ndim
387+
full_scalar_key = (
388+
(isinstance(key, (int, np.integer)) and ndim == 1) or
389+
(isinstance(key, tuple) and len(key) == ndim and all(isinstance(k, (int, np.integer)) for k in key))
390+
)
391+
# fast path when setting a single cell
392+
if full_scalar_key:
393+
array.data[key] = value
394+
else:
395+
array[self._translate_key(key)] = value
362396

363397
def __len__(self):
364398
return len(self.array)
365399

400+
def __iter__(self):
401+
array = self.array
402+
# fast path for 1D arrays (where we return scalars)
403+
if array.ndim <= 1:
404+
return iter(array.data)
405+
else:
406+
return LArrayIterator(array)
407+
366408

367409
class LArrayPointsIndexer(object):
368410
__slots__ = ('array',)
@@ -2696,6 +2738,7 @@ def _group_aggregate(self, op, items, keepaxes=False, out=None, **kwargs):
26962738
arr = np.asarray(arr)
26972739
op(arr, axis=axis_idx, out=out, **kwargs)
26982740
del arr
2741+
26992742
if killaxis:
27002743
assert group_idx[axis_idx] == 0
27012744
res_data = res_data[idx]

0 commit comments

Comments
 (0)