Skip to content

Commit 3dbf6d6

Browse files
authored
Multiple outputs (#419)
* Multiple outputs * Get general_blockwise core op working for multiple outputs * Handle fusion of multiple outputs Test for child fusion (where multiple outputs op is fused with its two children) Test for child fusion (where multiple outputs op is fused with its two children) test_fuse_multiple_outputs_diamond sibling fusion test * Mem utilization test for multiple outputs * Allow multiple output functions to just return a tuple * Fix for Zarr v3
1 parent bafa0b3 commit 3dbf6d6

File tree

10 files changed

+448
-118
lines changed

10 files changed

+448
-118
lines changed

cubed/array_api/manipulation_functions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -321,9 +321,9 @@ def key_function(out_key):
321321
key_function,
322322
x,
323323
template,
324-
shape=shape,
325-
dtype=x.dtype,
326-
chunks=outchunks,
324+
shapes=[shape],
325+
dtypes=[x.dtype],
326+
chunkss=[outchunks],
327327
)
328328

329329

@@ -402,9 +402,9 @@ def key_function(out_key):
402402
_read_stack_chunk,
403403
key_function,
404404
*arrays,
405-
shape=shape,
406-
dtype=dtype,
407-
chunks=chunks,
405+
shapes=[shape],
406+
dtypes=[dtype],
407+
chunkss=[chunks],
408408
axis=axis,
409409
fusable=False,
410410
)

cubed/core/ops.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from itertools import product
66
from numbers import Integral, Number
77
from operator import add
8-
from typing import TYPE_CHECKING, Any, Sequence, Union
8+
from typing import TYPE_CHECKING, Any, Sequence, Tuple, Union
99
from warnings import warn
1010

1111
import ndindex
@@ -333,14 +333,14 @@ def general_blockwise(
333333
func,
334334
key_function,
335335
*arrays,
336-
shape,
337-
dtype,
338-
chunks,
339-
target_store=None,
340-
target_path=None,
336+
shapes,
337+
dtypes,
338+
chunkss,
339+
target_stores=None,
340+
target_paths=None,
341341
extra_func_kwargs=None,
342342
**kwargs,
343-
) -> "Array":
343+
) -> Union["Array", Tuple["Array", ...]]:
344344
assert len(arrays) > 0
345345

346346
# replace arrays with zarr arrays
@@ -354,24 +354,33 @@ def general_blockwise(
354354

355355
num_input_blocks = kwargs.pop("num_input_blocks", None)
356356

357-
name = gensym()
358357
spec = check_array_specs(arrays)
359-
if target_store is None:
360-
target_store = new_temp_path(name=name, spec=spec)
358+
359+
if isinstance(target_stores, list): # multiple outputs
360+
name = [gensym() for _ in range(len(target_stores))]
361+
target_stores = [
362+
ts if ts is not None else new_temp_path(name=n, spec=spec)
363+
for n, ts in zip(name, target_stores)
364+
]
365+
else: # single output
366+
name = gensym()
367+
if target_stores is None:
368+
target_stores = [new_temp_path(name=name, spec=spec)]
369+
361370
op = primitive_general_blockwise(
362371
func,
363372
key_function,
364373
*zargs,
365374
allowed_mem=spec.allowed_mem,
366375
reserved_mem=spec.reserved_mem,
367376
extra_projected_mem=extra_projected_mem,
368-
target_store=target_store,
369-
target_path=target_path,
377+
target_stores=target_stores,
378+
target_paths=target_paths,
370379
storage_options=spec.storage_options,
371380
compressor=spec.zarr_compressor,
372-
shape=shape,
373-
dtype=dtype,
374-
chunks=chunks,
381+
shapes=shapes,
382+
dtypes=dtypes,
383+
chunkss=chunkss,
375384
in_names=in_names,
376385
extra_func_kwargs=extra_func_kwargs,
377386
num_input_blocks=num_input_blocks,
@@ -387,7 +396,10 @@ def general_blockwise(
387396
)
388397
from cubed.array_api import Array
389398

390-
return Array(name, op.target_array, spec, plan)
399+
if isinstance(op.target_array, list): # multiple outputs
400+
return tuple(Array(n, ta, spec, plan) for n, ta in zip(name, op.target_array))
401+
else: # single output
402+
return Array(name, op.target_array, spec, plan)
391403

392404

393405
def elemwise(func, *args: "Array", dtype=None) -> "Array":
@@ -914,9 +926,9 @@ def key_function(out_key):
914926
_concatenate2,
915927
key_function,
916928
x,
917-
shape=x.shape,
918-
dtype=x.dtype,
919-
chunks=target_chunks,
929+
shapes=[x.shape],
930+
dtypes=[x.dtype],
931+
chunkss=[target_chunks],
920932
extra_projected_mem=0,
921933
num_input_blocks=(num_input_blocks,),
922934
axes=axes,
@@ -1229,12 +1241,12 @@ def partial_reduce(
12291241
axis = tuple(ax for ax in split_every.keys())
12301242
combine_sizes = combine_sizes or {}
12311243
combine_sizes = {k: combine_sizes.get(k, 1) for k in axis}
1232-
chunks = [
1244+
chunks = tuple(
12331245
(combine_sizes[i],) * math.ceil(len(c) / split_every[i])
12341246
if i in split_every
12351247
else c
12361248
for (i, c) in enumerate(x.chunks)
1237-
]
1249+
)
12381250
shape = tuple(map(sum, chunks))
12391251

12401252
def key_function(out_key):
@@ -1263,9 +1275,9 @@ def key_function(out_key):
12631275
_partial_reduce,
12641276
key_function,
12651277
x,
1266-
shape=shape,
1267-
dtype=dtype,
1268-
chunks=chunks,
1278+
shapes=[shape],
1279+
dtypes=[dtype],
1280+
chunkss=[chunks],
12691281
extra_projected_mem=extra_projected_mem,
12701282
num_input_blocks=(sum(split_every.values()),),
12711283
reduce_func=func,

cubed/core/optimization.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ def can_fuse(n):
3131
if "primitive_op" not in nodes[op2]:
3232
return False
3333

34-
# if node (op2) does not have exactly one input then don't fuse
34+
# if node (op2) does not have exactly one input and output then don't fuse
3535
# (it could have no inputs or multiple inputs)
36-
if dag.in_degree(op2) != 1:
36+
if dag.in_degree(op2) != 1 or dag.out_degree(op2) != 1:
3737
return False
3838

3939
# if input is one of the arrays being computed then don't fuse
@@ -91,6 +91,12 @@ def predecessors_unordered(dag, name):
9191
yield pre
9292

9393

94+
def successors_unordered(dag, name):
95+
"""Return a node's successors in no particular order, with repeats for multiple edges."""
96+
for pre, _ in dag.out_edges(name):
97+
yield pre
98+
99+
94100
def predecessor_ops(dag, name):
95101
"""Return an op node's op predecessors in the same order as the input source arrays for the op.
96102
@@ -183,6 +189,17 @@ def can_fuse_predecessors(
183189
)
184190
return False
185191

192+
# if any predecessor ops have multiple outputs then don't fuse
193+
# TODO: implement "child fusion" (where a multiple output op fuses its children)
194+
if any(
195+
len(list(successors_unordered(dag, pre))) > 1
196+
for pre in predecessor_ops(dag, name)
197+
):
198+
logger.debug(
199+
"can't fuse %s since at least one predecessor has multiple outputs", name
200+
)
201+
return False
202+
186203
# if node is in never_fuse or always_fuse list then it overrides logic below
187204
if never_fuse is not None and name in never_fuse:
188205
logger.debug("can't fuse %s since it is in 'never_fuse'", name)

cubed/core/plan.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class Plan:
7373
def __init__(self, dag):
7474
self.dag = dag
7575

76-
# args from pipeline onwards are omitted for creation functions when no computation is needed
76+
# args from primitive_op onwards are omitted for creation functions when no computation is needed
7777
@classmethod
7878
def _new(
7979
cls,
@@ -110,15 +110,26 @@ def _new(
110110
op_display_name=f"{op_name_unique}\n{first_cubed_summary.name}",
111111
hidden=hidden,
112112
)
113-
# array (when multiple outputs are supported there could be more than one)
114-
dag.add_node(
115-
name,
116-
name=name,
117-
type="array",
118-
target=target,
119-
hidden=hidden,
120-
)
121-
dag.add_edge(op_name_unique, name)
113+
# array
114+
if isinstance(name, list): # multiple outputs
115+
for n, t in zip(name, target):
116+
dag.add_node(
117+
n,
118+
name=n,
119+
type="array",
120+
target=t,
121+
hidden=hidden,
122+
)
123+
dag.add_edge(op_name_unique, n)
124+
else: # single output
125+
dag.add_node(
126+
name,
127+
name=name,
128+
type="array",
129+
target=target,
130+
hidden=hidden,
131+
)
132+
dag.add_edge(op_name_unique, name)
122133
else:
123134
# op
124135
dag.add_node(
@@ -132,15 +143,26 @@ def _new(
132143
primitive_op=primitive_op,
133144
pipeline=primitive_op.pipeline,
134145
)
135-
# array (when multiple outputs are supported there could be more than one)
136-
dag.add_node(
137-
name,
138-
name=name,
139-
type="array",
140-
target=target,
141-
hidden=hidden,
142-
)
143-
dag.add_edge(op_name_unique, name)
146+
# array
147+
if isinstance(name, list): # multiple outputs
148+
for n, t in zip(name, target):
149+
dag.add_node(
150+
n,
151+
name=n,
152+
type="array",
153+
target=t,
154+
hidden=hidden,
155+
)
156+
dag.add_edge(op_name_unique, n)
157+
else: # single output
158+
dag.add_node(
159+
name,
160+
name=name,
161+
type="array",
162+
target=target,
163+
hidden=hidden,
164+
)
165+
dag.add_edge(op_name_unique, name)
144166
for x in source_arrays:
145167
if hasattr(x, "name"):
146168
dag.add_edge(x.name, op_name_unique)

0 commit comments

Comments
 (0)