Skip to content

Commit 1e42342

Browse files
committed
simplified _adv_keys_to_combined_axis_la_keys
1 parent f8cd41d commit 1e42342

File tree

1 file changed

+17
-22
lines changed

1 file changed

+17
-22
lines changed

larray/core/axis.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3207,25 +3207,27 @@ def _adv_keys_to_combined_axis_la_keys(self, key, wildcard=False, sep='_'):
32073207
# TODO: use/factorize with AxisCollection.combine_axes. The problem is that it uses product(*axes_labels)
32083208
# while here we need zip(*axes_labels)
32093209
ignored_types = (int, np.integer, slice, LArray)
3210-
adv_key_axes = [axis for axis_key, axis in zip(key, self)
3211-
if not isinstance(axis_key, ignored_types)]
3212-
if not adv_key_axes:
3210+
adv_keys = [(axis_key, axis) for axis_key, axis in zip(key, self)
3211+
if not isinstance(axis_key, ignored_types)]
3212+
if not adv_keys:
32133213
return key
32143214

32153215
# axes with a scalar key are not taken, since we want to kill them
32163216

32173217
# all anonymous axes => anonymous combined axis
3218-
if all(axis.name is None for axis in adv_key_axes):
3218+
if all(axis.name is None for axis_key, axis in adv_keys):
32193219
combined_name = None
32203220
else:
32213221
# using axis_id instead of name to allow combining a mix of anonymous & non anonymous axes
3222-
combined_name = sep.join(str(self.axis_id(axis)) for axis in adv_key_axes)
3222+
combined_name = sep.join(str(self.axis_id(axis)) for axis_key, axis in adv_keys)
3223+
3224+
# explicitly check that all combined keys have the same length
3225+
first_key, first_axis = adv_keys[0]
3226+
combined_axis_len = len(first_key)
3227+
if not all(len(axis_key) == combined_axis_len for axis_key, axis in adv_keys[1:]):
3228+
raise ValueError("all combined keys should have the same length")
32233229

32243230
if wildcard:
3225-
lengths = [len(axis_key) for axis_key in key
3226-
if not isinstance(axis_key, ignored_types)]
3227-
combined_axis_len = lengths[0]
3228-
assert all(l == combined_axis_len for l in lengths)
32293231
combined_axis = Axis(combined_axis_len, combined_name)
32303232
else:
32313233
# TODO: the combined keys should be objects which display as:
@@ -3235,31 +3237,24 @@ def _adv_keys_to_combined_axis_la_keys(self, key, wildcard=False, sep='_'):
32353237
# A: yes, probably. On the Pandas backend, we could/should have
32363238
# separate axes. On the numpy backend we cannot.
32373239
# TODO: only convert if
3238-
if len(adv_key_axes) == 1:
3239-
# we don't convert to string when there is only a single axis
3240+
if len(adv_keys) == 1:
3241+
# we do not convert to string when there is only a single axis
32403242
axes_labels = [axis.labels[axis_key]
3241-
for axis_key, axis in zip(key, self)
3242-
if not isinstance(axis_key, ignored_types)]
3243+
for axis_key, axis in adv_keys]
32433244
# Q: if axis is a wildcard axis, should the result be a
32443245
# wildcard axis (and axes_labels discarded?)
32453246
combined_labels = axes_labels[0]
32463247
else:
32473248
axes_labels = [axis.labels.astype(np.str, copy=False)[axis_key].tolist()
3248-
for axis_key, axis in zip(key, self)
3249-
if not isinstance(axis_key, ignored_types)]
3249+
for axis_key, axis in adv_keys]
32503250
sepjoin = sep.join
32513251
combined_labels = [sepjoin(comb) for comb in zip(*axes_labels)]
32523252
combined_axis = Axis(combined_labels, combined_name)
32533253

32543254
# 2) transform all advanced non-LArray keys to LArray with the combined axis
32553255
# ==========================================================================
3256-
def to_la_key(axis_key, combined_axis):
3257-
if isinstance(axis_key, (int, np.integer, slice, LArray)):
3258-
return axis_key
3259-
else:
3260-
assert len(axis_key) == len(combined_axis)
3261-
return LArray(axis_key, combined_axis)
3262-
return tuple(to_la_key(axis_key, combined_axis) for axis_key in key)
3256+
return tuple(axis_key if isinstance(axis_key, ignored_types) else LArray(axis_key, combined_axis)
3257+
for axis_key in key)
32633258

32643259

32653260
class AxisReference(ABCAxisReference, ExprNode, Axis):

0 commit comments

Comments
 (0)