Commit 4b18053

Change default max_total_num_input_blocks to 10 (#615)
This allows operations to be fused as long as the total number of input blocks does not exceed this number. Previously the default was `None`, which meant operations were fused only if they had the same number of tasks.
1 parent b3396a9 commit 4b18053
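
As a hedged illustration of the rule this commit changes (not code from the commit itself; `op_num_tasks`, `fused_num_tasks`, and `total_num_input_blocks` are made-up names), the fusion decision can be sketched as:

```python
# Hypothetical sketch of the fusion rule described in the commit message.
# None of these names come from cubed's internals.
def would_fuse(
    op_num_tasks: int,
    fused_num_tasks: int,
    total_num_input_blocks: int,
    max_total_num_input_blocks: int | None = 10,  # new default per this commit
) -> bool:
    if max_total_num_input_blocks is None:
        # old default behavior: fuse only when the ops have the same number of tasks
        return op_num_tasks == fused_num_tasks
    # new default behavior: fuse whenever the fused task would read no more
    # than max_total_num_input_blocks input blocks in total
    return total_num_input_blocks <= max_total_num_input_blocks
```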

3 files changed: +44 -19 lines changed

cubed/core/optimization.py

Lines changed: 9 additions & 6 deletions
@@ -11,6 +11,9 @@
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_MAX_TOTAL_SOURCE_ARRAYS = 4
+DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS = 10
+
 
 def simple_optimize_dag(dag, array_names=None):
     """Apply map blocks fusion."""
@@ -154,8 +157,8 @@ def can_fuse_predecessors(
     name,
     *,
     array_names=None,
-    max_total_source_arrays=4,
-    max_total_num_input_blocks=None,
+    max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
+    max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
     always_fuse=None,
     never_fuse=None,
 ):
@@ -242,8 +245,8 @@ def fuse_predecessors(
     name,
     *,
     array_names=None,
-    max_total_source_arrays=4,
-    max_total_num_input_blocks=None,
+    max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
+    max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
     always_fuse=None,
     never_fuse=None,
 ):
@@ -297,8 +300,8 @@ def multiple_inputs_optimize_dag(
     dag,
     *,
     array_names=None,
-    max_total_source_arrays=4,
-    max_total_num_input_blocks=None,
+    max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
+    max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
     always_fuse=None,
     never_fuse=None,
 ):

cubed/tests/test_optimization.py

Lines changed: 33 additions & 11 deletions
@@ -11,6 +11,8 @@
 from cubed.backend_array_api import namespace as nxp
 from cubed.core.ops import elemwise, merge_chunks, partial_reduce
 from cubed.core.optimization import (
+    DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
+    DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
     fuse_all_optimize_dag,
     fuse_only_optimize_dag,
     fuse_predecessors,
@@ -223,7 +225,11 @@ def fuse_one_level(arr, *, always_fuse=None):
     )
 
 
-def fuse_multiple_levels(*, max_total_source_arrays=4, max_total_num_input_blocks=None):
+def fuse_multiple_levels(
+    *,
+    max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
+    max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
+):
     # use multiple_inputs_optimize_dag to test multiple levels of fusion
     return partial(
         multiple_inputs_optimize_dag,
@@ -899,8 +905,7 @@ def test_fuse_merge_chunks_unary(spec):
     b = xp.negative(a)
     c = merge_chunks(b, chunks=(3, 2))
 
-    # specify max_total_num_input_blocks to force c to fuse
-    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3)
+    opt_fn = fuse_multiple_levels()
 
     c.visualize(optimize_function=opt_fn)
 
@@ -921,6 +926,16 @@ def test_fuse_merge_chunks_unary(spec):
     result = c.compute(optimize_function=opt_fn)
     assert_array_equal(result, -np.ones((3, 2)))
 
+    # now set max_total_num_input_blocks=None which means
+    # "only fuse if ops have same number of tasks", which they don't here
+    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=None)
+    optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag
+
+    # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
+    assert not structurally_equivalent(
+        optimized_dag, expected_fused_dag, remove_hidden=True
+    )
+
 
 # merge chunks with different number of tasks (c has more tasks than d)
 #
@@ -936,8 +951,7 @@ def test_fuse_merge_chunks_binary(spec):
     c = xp.add(a, b)
     d = merge_chunks(c, chunks=(3, 2))
 
-    # specify max_total_num_input_blocks to force d to fuse
-    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6)
+    opt_fn = fuse_multiple_levels()
 
     d.visualize(optimize_function=opt_fn)
 
@@ -963,15 +977,24 @@ def test_fuse_merge_chunks_binary(spec):
     result = d.compute(optimize_function=opt_fn)
     assert_array_equal(result, 2 * np.ones((3, 2)))
 
+    # now set max_total_num_input_blocks=None which means
+    # "only fuse if ops have same number of tasks", which they don't here
+    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=None)
+    optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag
+
+    # merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
+    assert not structurally_equivalent(
+        optimized_dag, expected_fused_dag, remove_hidden=True
+    )
+
 
 # like test_fuse_merge_chunks_unary, except uses partial_reduce
 def test_fuse_partial_reduce_unary(spec):
     a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
     b = xp.negative(a)
     c = partial_reduce(b, np.sum, split_every={0: 3})
 
-    # specify max_total_num_input_blocks to force c to fuse
-    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3)
+    opt_fn = fuse_multiple_levels()
 
     c.visualize(optimize_function=opt_fn)
 
@@ -996,8 +1019,7 @@ def test_fuse_partial_reduce_binary(spec):
     c = xp.add(a, b)
     d = partial_reduce(c, np.sum, split_every={0: 3})
 
-    # specify max_total_num_input_blocks to force d to fuse
-    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6)
+    opt_fn = fuse_multiple_levels()
 
     d.visualize(optimize_function=opt_fn)
 
@@ -1176,7 +1198,7 @@ def test_optimize_stack(spec):
     c = xp.stack((a, b), axis=0)
     d = c + 1
     # try to fuse all ops into one (d will fuse with c, but c won't fuse with a and b)
-    d.compute(optimize_function=fuse_multiple_levels(max_total_num_input_blocks=10))
+    d.compute(optimize_function=fuse_multiple_levels())
 
 
 def test_optimize_concat(spec):
@@ -1186,4 +1208,4 @@ def test_optimize_concat(spec):
     c = xp.concat((a, b), axis=0)
     d = c + 1
     # try to fuse all ops into one (d will fuse with c, but c won't fuse with a and b)
-    d.compute(optimize_function=fuse_multiple_levels(max_total_num_input_blocks=10))
+    d.compute(optimize_function=fuse_multiple_levels())
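
Mirroring the unary test above, a hedged end-to-end sketch of the default-versus-opt-out behavior (the `Spec` construction is an assumption; the tests use a fixture):

```python
from functools import partial

import cubed
import cubed.array_api as xp
from cubed.core.ops import merge_chunks
from cubed.core.optimization import multiple_inputs_optimize_dag

spec = cubed.Spec(allowed_mem="100MB")  # assumed spec for the sketch
a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
b = xp.negative(a)
c = merge_chunks(b, chunks=(3, 2))

# With the new default cap of 10 input blocks, negative and merge_chunks
# fuse even though they have different numbers of tasks.
c.compute(optimize_function=multiple_inputs_optimize_dag)

# Passing None restores the old rule (equal task counts only), so no fusion here.
c.compute(
    optimize_function=partial(
        multiple_inputs_optimize_dag, max_total_num_input_blocks=None
    )
)
```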

docs/user-guide/optimization.md

Lines changed: 2 additions & 2 deletions
@@ -112,9 +112,9 @@ e.visualize(optimize_function=opt_fn)
 
 The `max_total_num_input_blocks` argument to `multiple_inputs_optimize_dag` specifies the maximum number of input blocks (chunks) that are allowed in the fused operation.
 
-Again, this is to limit the number of reads that an individual task must perform. The default is `None`, which means that operations are fused only if they have the same number of tasks. If set to an integer, then this limitation is removed, and tasks with a different number of tasks will be fused - as long as the total number of input blocks does not exceed the maximum. This setting is useful for reductions, and can be set using `functools.partial`:
+Again, this is to limit the number of reads that an individual task must perform. If set to `None`, operations are fused only if they have the same number of tasks. If set to an integer (the default is 10), then operations with a different number of tasks will be fused, as long as the total number of input blocks does not exceed the maximum. This setting is useful for reductions, and can be changed using `functools.partial`:
 
 ```python
-opt_fn = partial(multiple_inputs_optimize_dag, max_total_num_input_blocks=10)
+opt_fn = partial(multiple_inputs_optimize_dag, max_total_num_input_blocks=20)
 e.visualize(optimize_function=opt_fn)
 ```
