Skip to content

Commit 3577310

Browse files
committed
implemented LArray.iflat
1 parent e22f05f commit 3577310

File tree

3 files changed

+89
-1
lines changed

3 files changed

+89
-1
lines changed

doc/source/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ Modifying/Selecting
314314
LArray.i
315315
LArray.points
316316
LArray.ipoints
317+
LArray.iflat
317318
LArray.set
318319
LArray.drop
319320
LArray.ignore_labels

larray/core/array.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,87 @@ def __setitem__(self, key, value):
388388
self.array.__setitem__(self._prepare_key(key, wildcard=True), value, translate_key=False)
389389

390390

391+
# TODO: add support for slices
392+
# To select the first 4 values across all axes:
393+
#
394+
# >>> arr.iflat[:4]
395+
# a_b a0_b0 a0_b1 a0_b2 a1_b0
396+
# 0 10 20 30
397+
class LArrayFlatIndicesIndexer(object):
398+
r"""
399+
Access the array by index as if it was flat (one dimensional) and all its axes were combined.
400+
401+
Notes
402+
-----
403+
In general arr.iflat[key] should be equivalent to (but much faster than) arr.combine_axes().i[key]
404+
405+
Examples
406+
--------
407+
>>> arr = ndtest((2, 3)) * 10
408+
>>> arr
409+
a\b b0 b1 b2
410+
a0 0 10 20
411+
a1 30 40 50
412+
413+
To select the first, second, fourth and fifth values across all axes:
414+
415+
>>> arr.combine_axes().i[[0, 1, 3, 4]]
416+
a_b a0_b0 a0_b1 a1_b0 a1_b1
417+
0 10 30 40
418+
>>> arr.iflat[[0, 1, 3, 4]]
419+
a_b a0_b0 a0_b1 a1_b0 a1_b1
420+
0 10 30 40
421+
422+
Set the first and sixth values to 42
423+
424+
>>> arr.iflat[[0, 5]] = 42
425+
>>> arr
426+
a\b b0 b1 b2
427+
a0 42 10 20
428+
a1 30 40 42
429+
430+
When the key is an LArray, the result will have the axes of the key
431+
432+
>>> key = LArray([0, 3], 'c=c0,c1')
433+
>>> key
434+
c c0 c1
435+
0 3
436+
>>> arr.iflat[key]
437+
c c0 c1
438+
42 30
439+
"""
440+
__slots__ = ('array',)
441+
442+
def __init__(self, array):
443+
self.array = array
444+
445+
def __getitem__(self, flat_key, sep='_'):
446+
if isinstance(flat_key, ABCLArray):
447+
flat_np_key = flat_key.data
448+
res_axes = flat_key.axes
449+
else:
450+
flat_np_key = np.asarray(flat_key)
451+
axes = self.array.axes
452+
nd_key = np.unravel_index(flat_np_key, axes.shape)
453+
# the following lines are equivalent to (but faster than) "return array.ipoints[nd_key]"
454+
455+
# TODO: extract a function which only computes the combined axes because we do not use the actual LArrays
456+
# produced here, which is wasteful. AxisCollection._flat_lookup seems related (but not usable as-is).
457+
la_key = axes._adv_keys_to_combined_axis_la_keys(nd_key, sep=sep)
458+
first_axis_key_axes = la_key[0].axes
459+
assert all(isinstance(axis_key, ABCLArray) and axis_key.axes is first_axis_key_axes
460+
for axis_key in la_key[1:])
461+
res_axes = first_axis_key_axes
462+
return LArray(self.array.data.flat[flat_np_key], res_axes)
463+
464+
def __setitem__(self, flat_key, value):
465+
# np.ndarray.flat is a flatiter object but it is indexable despite the name
466+
self.array.data.flat[flat_key] = value
467+
468+
def __len__(self):
469+
return self.array.size
470+
471+
391472
# TODO: rename to LArrayIndexPointsIndexer or something like that
392473
class LArrayPositionalPointsIndexer(object):
393474
__slots__ = ('array',)
@@ -3420,6 +3501,11 @@ def items(self, axes=None, ascending=True):
34203501
"""
34213502
return SequenceZip((self.keys(axes, ascending=ascending), self.values(axes, ascending=ascending)))
34223503

3504+
@lazy_attribute
3505+
def iflat(self):
3506+
return LArrayFlatIndicesIndexer(self)
3507+
iflat.__doc__ = LArrayFlatIndicesIndexer.__doc__
3508+
34233509
def copy(self):
34243510
"""Returns a copy of the array.
34253511
"""

larray/core/axis.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3378,10 +3378,11 @@ def _adv_keys_to_combined_axis_la_keys(self, key, wildcard=False, sep='_'):
33783378
sepjoin = sep.join
33793379
combined_labels = [sepjoin(comb) for comb in zip(*axes_labels)]
33803380
combined_axis = Axis(combined_labels, combined_name)
3381+
combined_axes = AxisCollection(combined_axis)
33813382

33823383
# 2) transform all advanced non-LArray keys to LArray with the combined axis
33833384
# ==========================================================================
3384-
return tuple(axis_key if isinstance(axis_key, ignored_types) else LArray(axis_key, combined_axis)
3385+
return tuple(axis_key if isinstance(axis_key, ignored_types) else LArray(axis_key, combined_axes)
33853386
for axis_key in key)
33863387

33873388

0 commit comments

Comments
 (0)