Skip to content

Commit d6021a6

Browse files
committed
faster LArray.apply when the function returns several values
1 parent 084b460 commit d6021a6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

larray/core/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7910,8 +7910,8 @@ def apply(self, transform, *args, **kwargs):
79107910
if isinstance(first_value, tuple):
79117911
# assume all other values are the same shape
79127912
tuple_length = len(first_value)
7913-
# TODO: compute res_axes (potentially different for each return value) in this case too
7914-
res_arrays = [stack([(key, value[i]) for key, value in key_values], axes=by, dtype=dtype)
7913+
res_arrays = [stack([(key, value[i]) for key, value in key_values], axes=by, dtype=dtype,
7914+
res_axes=get_axes(first_value[i]).union(by))
79157915
for i in range(tuple_length)]
79167916
# transpose back axis where it was
79177917
return tuple(res_arr.transpose(self.axes & res_arr.axes) for res_arr in res_arrays)

0 commit comments

Comments
 (0)