Skip to content

Commit 50181e0

Browse files
authored
Implement unstack using multiple outputs (#575)
1 parent 8a406dc commit 50181e0

File tree

5 files changed

+75
-0
lines changed

5 files changed

+75
-0
lines changed

cubed/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@
286286
roll,
287287
squeeze,
288288
stack,
289+
unstack,
289290
)
290291

291292
__all__ += [
@@ -300,6 +301,7 @@
300301
"roll",
301302
"squeeze",
302303
"stack",
304+
"unstack",
303305
]
304306

305307
from .array_api.searching_functions import argmax, argmin, where

cubed/array_api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@
228228
roll,
229229
squeeze,
230230
stack,
231+
unstack,
231232
)
232233

233234
__all__ += [
@@ -242,6 +243,7 @@
242243
"roll",
243244
"squeeze",
244245
"stack",
246+
"unstack",
245247
]
246248

247249
from .searching_functions import argmax, argmin, where

cubed/array_api/manipulation_functions.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,3 +412,43 @@ def key_function(out_key):
412412

413413
def _read_stack_chunk(array, axis=None):
414414
return nxp.expand_dims(array, axis=axis)
415+
416+
417+
def unstack(x, /, *, axis=0):
418+
axis = validate_axis(axis, x.ndim)
419+
420+
n_arrays = x.shape[axis]
421+
422+
if n_arrays == 1:
423+
return (x,)
424+
425+
shape = x.shape[:axis] + x.shape[axis + 1 :]
426+
dtype = x.dtype
427+
chunks = x.chunks[:axis] + x.chunks[axis + 1 :]
428+
429+
def key_function(out_key):
430+
out_coords = out_key[1:]
431+
all_in_coords = tuple(
432+
out_coords[:axis] + (i,) + out_coords[axis:]
433+
for i in range(x.numblocks[axis])
434+
)
435+
return tuple((x.name,) + in_coords for in_coords in all_in_coords)
436+
437+
return general_blockwise(
438+
_unstack_chunk,
439+
key_function,
440+
x,
441+
shapes=[shape] * n_arrays,
442+
dtypes=[dtype] * n_arrays,
443+
chunkss=[chunks] * n_arrays,
444+
target_stores=[None] * n_arrays, # filled in by general_blockwise
445+
axis=axis,
446+
)
447+
448+
449+
def _unstack_chunk(*arrs, axis=0):
450+
# unstack each array in arrs and yield all in turn
451+
for arr in arrs:
452+
# TODO: replace with nxp.unstack(arr, axis=axis) when array-api-compat has unstack
453+
for a in tuple(nxp.moveaxis(arr, axis, 0)):
454+
yield a

cubed/tests/test_array_api.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,28 @@ def test_stack(spec, executor):
627627
)
628628

629629

630+
@pytest.mark.parametrize("chunks", [(1, 2, 3), (2, 2, 3), (3, 2, 3)])
631+
def test_unstack(spec, executor, chunks):
632+
a = xp.full((4, 6), 1, chunks=(2, 3), spec=spec)
633+
b = xp.full((4, 6), 2, chunks=(2, 3), spec=spec)
634+
c = xp.full((4, 6), 3, chunks=(2, 3), spec=spec)
635+
d = xp.stack([a, b, c], axis=0)
636+
637+
d = d.rechunk(chunks)
638+
639+
au, bu, cu = cubed.compute(*xp.unstack(d), executor=executor, optimize_graph=False)
640+
641+
assert_array_equal(au, np.full((4, 6), 1))
642+
assert_array_equal(bu, np.full((4, 6), 2))
643+
assert_array_equal(cu, np.full((4, 6), 3))
644+
645+
646+
def test_unstack_noop(spec):
647+
a = xp.full((1, 4, 6), 1, chunks=(1, 2, 3), spec=spec)
648+
(b,) = xp.unstack(a)
649+
assert a is b
650+
651+
630652
# Searching functions
631653

632654

cubed/tests/test_mem_utilization.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,15 @@ def test_stack(tmp_path, spec, executor):
279279
run_operation(tmp_path, executor, "stack", c)
280280

281281

282+
@pytest.mark.slow
283+
def test_unstack(tmp_path, spec, executor):
284+
a = cubed.random.random(
285+
(2, 10000, 10000), chunks=(2, 5000, 5000), spec=spec
286+
) # 400MB chunks
287+
b, c = xp.unstack(a)
288+
run_operation(tmp_path, executor, "unstack", b, c)
289+
290+
282291
# Searching Functions
283292

284293

0 commit comments

Comments
 (0)