Skip to content

Commit fcf5115

Browse files
sharadmvGoogle-ML-Automation
authored and committed
[Pallas Fuser] Add output_fusion_mask support
Currently, the fusion API assumes by default that all of the outputs of a @fuse-decorated function are computed jointly in one big output fusion. For example, in the following snippet

```python
@fuse
def f(x, y):
  z1, z2 = fusable_f(x, y)
  return g(z1, z2)
```

it assumes that `g` is a single function that operates on z1 and z2 jointly. However, in practice, the fusable may want two separate output fusions:

```python
@fuse
def f(x, y):
  z1, z2 = fusable_f(x, y)
  return g1(z1), g2(z2)
```

This is a special case of the general function but the fusable may not be materializing z1 and z2 at the same time so may not be able to compute this efficiently with a single function g. By decorating a fusable with an output fusion prefix (in the above example `(True, True)`), the fusable will now be given a pair of functions `g1` and `g2` if the output fusion is "separable". For example, we'd error for the following example:

```python
@fuse
def f(x, y):
  z1, z2 = fusable_f(x, y)
  return z1 + z2
```

because z1 and z2 interact with each other in the output fusion. The rationale for providing a PyTree prefix (as opposed to a more general mechanism) is that the fusable can group its outputs into subtrees that it can identify with the output prefix. This does restrict the types of output groups that are possible (outputs must be part of the same shared subtree, as opposed to arbitrarily scattered throughout the output pytree), but this is an okay restriction because the fusable author is responsible for the grouping and can always construct it that way.

PiperOrigin-RevId: 744784770
1 parent 855829e commit fcf5115

File tree

6 files changed

+469
-56
lines changed

6 files changed

+469
-56
lines changed

jax/_src/pallas/fuser/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ pytype_strict_library(
9999
"//jax:core",
100100
"//jax:partial_eval",
101101
"//jax:tree_util",
102+
"//jax:util",
102103
],
103104
)
104105

jax/_src/pallas/fuser/fusable.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Fusable primitive."""
16+
from typing import Any
1617

1718
import jax
1819
from jax._src import api_util
@@ -40,32 +41,38 @@ def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion:
4041
)
4142

4243

43-
def fusable(f):
44-
def wrapper(*args):
45-
def wrapped(*args):
46-
in_fusions = tree_util.tree_map(_make_trivial_fusion, args)
47-
return f(*in_fusions, None)
48-
49-
flat_args, in_tree = tree_util.tree_flatten(args)
50-
debug_info = api_util.debug_info('fusable', wrapped, args, {})
51-
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
52-
lu.wrap_init(wrapped, debug_info=debug_info), in_tree
53-
)
54-
flat_avals = [_get_aval(x) for x in flat_args]
55-
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
56-
out_tree = out_tree_thunk()
57-
out = fusable_p.bind(
58-
*consts,
59-
*flat_args,
60-
jaxpr=jaxpr,
61-
num_consts=len(consts),
62-
in_tree=in_tree,
63-
out_tree=out_tree,
64-
func=f,
65-
)
66-
return tree_util.tree_unflatten(out_tree, out)
67-
68-
return wrapper
44+
def fusable(f=None, *, output_fusion_prefix: Any = True):
  """Turns `f` into a function that binds the `fusable_p` primitive.

  Works both as a bare decorator (`@fusable`) and as a decorator factory
  (`@fusable(output_fusion_prefix=...)`).

  Args:
    f: The function to wrap. It is invoked as `f(*in_fusions, None)` while
      tracing, i.e. with one trivial fusion per argument leaf followed by a
      trailing `None` in the output-fusion position. When `f` is omitted, a
      decorator is returned instead.
    output_fusion_prefix: Forwarded unchanged to `fusable_p.bind` as the
      `output_fusion_prefix` parameter. Defaults to `True`.

  Returns:
    The wrapped function when `f` is given, otherwise a decorator that
    produces it.
  """

  def decorator(f):
    def wrapper(*args):
      def wrapped(*args):
        # Every argument leaf is replaced by a trivial fusion; the final
        # positional argument (the output fusion) is None during tracing.
        in_fusions = tree_util.tree_map(_make_trivial_fusion, args)
        return f(*in_fusions, None)

      flat_args, in_tree = tree_util.tree_flatten(args)
      flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
          lu.wrap_init(
              wrapped,
              debug_info=api_util.debug_info('fusable', wrapped, args, {}),
          ),
          in_tree,
      )
      jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(
          flat_fun, list(map(_get_aval, flat_args))
      )
      # Only valid once tracing has run, hence read after the trace above.
      out_tree = out_tree_thunk()
      bind_params = dict(
          jaxpr=jaxpr,
          num_consts=len(consts),
          in_tree=in_tree,
          out_tree=out_tree,
          func=f,
          output_fusion_prefix=output_fusion_prefix,
      )
      out_flat = fusable_p.bind(*consts, *flat_args, **bind_params)
      return tree_util.tree_unflatten(out_tree, out_flat)

    return wrapper

  return decorator if f is None else decorator(f)
6976

7077

7178
@fusable_p.def_impl

jax/_src/pallas/fuser/jaxpr_fusion.py

Lines changed: 177 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414

1515
"""Fuses a function."""
1616

17+
from collections.abc import Sequence
18+
import functools
1719
from typing import Any
18-
1920
import jax
2021
from jax._src import api_util
2122
from jax._src import core as jax_core
2223
from jax._src import linear_util as lu
2324
from jax._src import tree_util
2425
from jax._src.interpreters import partial_eval as pe
25-
2626
from jax._src.pallas.fuser import fusable_dtype
2727
from jax._src.pallas.fuser import fusion as fusion_lib
2828
from jax._src.pallas.fuser.fusable import fusable_p
@@ -73,9 +73,9 @@ def wrapper(*args, **kwargs):
7373
_fusable: dict[jax_core.Primitive, Any] = {}
7474

7575

76-
def construct_fusion(
76+
def _construct_fusion_jaxpr(
7777
candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs
78-
) -> fusion_lib.Fusion:
78+
):
7979
flat_outvars, out_tree = tree_util.tree_flatten(outvars)
8080
flat_invars, in_tree = tree_util.tree_flatten((invars, kwargs))
8181
new_jaxpr_no_dce = jaxpr.replace(
@@ -94,12 +94,6 @@ def construct_fusion(
9494
c for used, c in zip(used_consts, candidate_values, strict=True) if used
9595
)
9696
kernel_in_tree = tree_util.tree_structure((invars, kwargs))
97-
98-
def _fn(*args, **kwargs):
99-
flat_args, _ = tree_util.tree_flatten((args, kwargs))
100-
out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values, *flat_args)
101-
return tree_util.tree_unflatten(out_tree, out_flat)
102-
10397
flat_in_type = [
10498
jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_invars
10599
]
@@ -108,9 +102,158 @@ def _fn(*args, **kwargs):
108102
out_tree,
109103
[jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_outvars],
110104
)
105+
return new_jaxpr, new_values, in_type, out_type, out_tree
106+
107+
108+
def construct_fusion(
    candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs
) -> fusion_lib.Fusion:
  """Wraps the jaxpr built by `_construct_fusion_jaxpr` in a `Fusion`.

  All arguments are forwarded as-is to `_construct_fusion_jaxpr`; the jaxpr,
  constant values, input/output types and output tree it returns are used to
  build the resulting `Fusion`.

  Returns:
    A `Fusion` whose callable flattens its `(args, kwargs)`, evaluates the
    constructed jaxpr on the flat leaves, and unflattens the results with the
    output tree.
  """
  fusion_parts = _construct_fusion_jaxpr(
      candidate_values, jaxpr, outvars, *invars, **kwargs
  )
  fused_jaxpr, fused_values, in_type, out_type, out_tree = fusion_parts

  def _fn(*args, **kwargs):
    leaves = tree_util.tree_flatten((args, kwargs))[0]
    results = jax_core.eval_jaxpr(fused_jaxpr, fused_values, *leaves)
    return tree_util.tree_unflatten(out_tree, results)

  return fusion_lib.Fusion(_fn, in_type, out_type)
112121

113122

123+
def _find_downstream(
    jaxpr: jax_core.Jaxpr, in_used: Sequence[bool]
) -> tuple[bool, ...]:
  """Returns a per-output mask derived from the inputs flagged in `in_used`.

  Args:
    jaxpr: The jaxpr to analyze.
    in_used: One flag per jaxpr input; passed to partial evaluation as both
      `in_unknowns` and `in_inst`.

  Returns:
    The third element of the `pe.partial_eval_jaxpr_custom` result, converted
    to a tuple. The caller treats it as one flag per jaxpr output marking the
    outputs downstream of the flagged inputs.
  """
  # TODO(sharadmv): We use partial_eval to query downstream dependencies which
  # is not an officially sanctioned way to do so, since PE is really used for
  # AD. In the future, we should have a special Jaxpr API that queries this.

  def _never_saveable(*_, **__):
    return False

  pe_result = pe.partial_eval_jaxpr_custom(
      jaxpr,
      in_unknowns=in_used,
      in_inst=in_used,
      ensure_out_unknowns=False,
      ensure_out_inst=False,
      saveable=_never_saveable,
  )
  _, _, out_used, *_ = pe_result
  return tuple(out_used)
138+
139+
140+
def _construct_output_permutation(
141+
used: list[tuple[bool, ...]],
142+
) -> list[int]:
143+
order = []
144+
for u in used:
145+
true_vals = [i for i in range(len(u)) if u[i]]
146+
order.extend(true_vals)
147+
return [order.index(i) for i in range(len(order))]
148+
149+
150+
def _construct_output_fusions(
    candidate_values,
    jaxpr,
    out_tree,
    fusion_eqn_index,
    fusion_eqn_outvars,  # Flat list of vars output by the fusable eqn
    fusion_eqn_out_tree,  # Tree structure of the fusable eqn outputs
    output_fusion_prefix,  # Pytree defining output groups
):
  """Builds one output `Fusion` per group selected by `output_fusion_prefix`.

  Args:
    candidate_values: Values forwarded to `_construct_fusion_jaxpr`.
    jaxpr: The full jaxpr containing the fusable equation.
    out_tree: Tree structure used to unflatten `jaxpr.outvars`.
    fusion_eqn_index: Index of the fusable equation in `jaxpr.eqns`; that
      equation is removed before building the post-fusable jaxpr.
    fusion_eqn_outvars: Flat list of vars output by the fusable equation.
    fusion_eqn_out_tree: Tree structure of the fusable equation's outputs.
    output_fusion_prefix: Pytree prefix of the fusable's outputs; each of its
      leaves defines one output group.

  Returns:
    A pair `(output_fusions, output_permutation)`. `output_fusions` has the
    tree structure of `output_fusion_prefix` with one `Fusion` per leaf; each
    fusion's callable returns a flat tuple of outputs. `output_permutation`
    is the result of `_construct_output_permutation` over the per-group
    downstream masks.

  Raises:
    ValueError: If two groups affect the same final output, i.e. the output
      fusion is not separable.
  """
  # 1. Create jaxpr_out: represents computation *after* the fusable
  #    Inputs: fusion_eqn_outvars
  #    Outputs: jaxpr.outvars
  jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr(
      candidate_values,
      jaxpr.replace(
          eqns=jaxpr.eqns[:fusion_eqn_index]
          + jaxpr.eqns[fusion_eqn_index + 1 :]
      ),
      tree_util.tree_unflatten(out_tree, jaxpr.outvars),  # Original outputs
      tree_util.tree_unflatten(
          fusion_eqn_out_tree, fusion_eqn_outvars
      ),  # Fusable outputs as inputs
  )

  # 2. Group fusable outputs based on the mask
  unflat_fusable_outvars = jax.tree.unflatten(
      fusion_eqn_out_tree, fusion_eqn_outvars
  )
  prefix_tree = jax.tree.structure(output_fusion_prefix)
  partial_flat = prefix_tree.flatten_up_to(unflat_fusable_outvars)

  # 3. Calculate dependencies and check disjointness
  group_in_used_masks = []  # List of bool lists over jaxpr_out.invars
  downstream_outputs_used_masks = []  # List of bool tuples, one per group
  already_used_final_outputs = set()  # Indices of final outputs already claimed
  for outvars_group in partial_flat:
    # Identify vars in this group
    used_fusable_outvars = set(jax.tree.leaves(outvars_group))
    # Create mask for jaxpr_out inputs corresponding to this group
    in_used_mask = [v in used_fusable_outvars for v in jaxpr_out.invars]
    # Trace dependencies through jaxpr_out to find which final outputs are
    # affected
    downstream_used_mask = _find_downstream(
        jaxpr_out, in_used_mask
    )  # Mask for jaxpr_out.outvars (== jaxpr.outvars)

    # Check for overlap in final output usage across groups
    for i, used in enumerate(downstream_used_mask):
      if not used:
        continue
      if i in already_used_final_outputs:
        raise ValueError(
            "Outputs must be disjoint in order to use separate output fusions"
        )
      already_used_final_outputs.add(i)
    group_in_used_masks.append(in_used_mask)
    downstream_outputs_used_masks.append(downstream_used_mask)

  # 4. Construct output permutation needed to restore original output order
  output_permutation = _construct_output_permutation(
      downstream_outputs_used_masks
  )

  # 5. Construct fusions for each group by DCEing the jaxpr_out
  def _fn(jaxpr, vals, *args, **kwargs):
    # The jaxpr and its values are bound per group with functools.partial, so
    # this function does not close over any loop variable.
    flat_args, _ = tree_util.tree_flatten((args, kwargs))
    out_flat = jax_core.eval_jaxpr(jaxpr, vals, *flat_args)
    return tuple(out_flat)

  output_fusions = []
  for outvars_group, in_used_mask, downstream_used_mask in zip(
      partial_flat,
      group_in_used_masks,
      downstream_outputs_used_masks,
      strict=True,
  ):
    # Never force-instantiate the constant values; reuse the group's input
    # mask from step 3 for the remaining entries.
    used_jaxpr_invars = [False] * len(all_values) + in_used_mask
    jaxpr_out_for_group, used_consts, _ = pe.dce_jaxpr_consts(
        jaxpr_out, downstream_used_mask, instantiate=used_jaxpr_invars
    )
    values_for_jaxpr = tuple(
        c for used, c in zip(used_consts, all_values, strict=True) if used
    )

    fn = functools.partial(_fn, jaxpr_out_for_group, values_for_jaxpr)
    in_type = jax.tree.map(
        lambda v: jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype),  # pytype: disable=attribute-error
        outvars_group,
    )
    out_type = tuple(
        jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype)  # pytype: disable=attribute-error
        for v in jaxpr_out_for_group.outvars
    )
    fusion = fusion_lib.Fusion(
        fn,
        (in_type, {}),
        out_type,
    )
    output_fusions.append(fusion)

  return (
      tree_util.tree_unflatten(prefix_tree, output_fusions),
      output_permutation,
  )
255+
256+
114257
def fuse_jaxpr(
115258
jaxpr: jax_core.Jaxpr, out_tree: tree_util.PyTreeDef, consts, *args
116259
):
@@ -125,6 +268,15 @@ def fuse_jaxpr(
125268
raise ValueError("No fusable eqn found")
126269
fusion_eqn = jaxpr.eqns[fusion_eqn_index]
127270

271+
# Now let's check if we need to do any fusion at all, e.g. do the outputs of
272+
# the jaxpr have any dependence on the fusion at all? We can DCE the jaxpr
273+
# with all the inputs and outputs to check if there is a dependence.
274+
dced_jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars),
275+
instantiate=True)
276+
if not any(eqn.primitive is fusable_p for eqn in dced_jaxpr.eqns):
277+
# Short circuit if there is nothing to fuse.
278+
return jax_core.eval_jaxpr(dced_jaxpr, consts, *args)
279+
128280
candidate_values = [*consts, *args]
129281

130282
# Construct fusions for non-constant inputs to the fusable.
@@ -141,21 +293,20 @@ def fuse_jaxpr(
141293
in_fusions = tree_util.tree_unflatten(
142294
fusion_eqn.params["in_tree"], in_fusions_flat
143295
)
144-
out_fusion = construct_fusion(
296+
output_fusions, output_permutation = _construct_output_fusions(
145297
candidate_values,
146-
jaxpr.replace(
147-
eqns=jaxpr.eqns[:fusion_eqn_index]
148-
+ jaxpr.eqns[fusion_eqn_index + 1 :]
149-
),
150-
tree_util.tree_unflatten(out_tree, jaxpr.outvars),
151-
tree_util.tree_unflatten(
152-
fusion_eqn.params["out_tree"], fusion_eqn.outvars
153-
),
298+
jaxpr,
299+
out_tree,
300+
fusion_eqn_index,
301+
fusion_eqn.outvars,
302+
fusion_eqn.params["out_tree"],
303+
fusion_eqn.params["output_fusion_prefix"],
154304
)
155-
# Run the fusable.
156-
out = fusion_eqn.params["func"](*in_fusions, out_fusion)
157-
158-
# Now return the flattened output (the fuse_jaxpr caller should unflatten).
159-
out_flat = tree_util.tree_leaves(out)
160-
assert len(out_flat) == len(jaxpr.outvars)
161-
return out_flat
305+
out = fusion_eqn.params["func"](*in_fusions, output_fusions)
306+
flat_out = jax.tree.leaves(out)
307+
permuted_out = [flat_out[i] for i in output_permutation]
308+
assert len(permuted_out) == len(jaxpr.outvars), (
309+
len(permuted_out),
310+
len(jaxpr.outvars),
311+
)
312+
return permuted_out

tests/pallas/BUILD

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,27 @@ jax_multiplatform_test(
680680
] + py_deps("absl/testing") + py_deps("numpy"),
681681
)
682682

683+
jax_multiplatform_test(
684+
name = "fusion_test",
685+
srcs = [
686+
"fusion_test.py",
687+
],
688+
disable_configs = [
689+
"cpu",
690+
"cpu_shardy",
691+
],
692+
enable_backends = ["cpu"],
693+
tags = [
694+
"noasan",
695+
"nomsan",
696+
"notsan",
697+
],
698+
deps = [
699+
"//jax:pallas",
700+
"//jax:pallas_fuser",
701+
] + py_deps("absl/testing") + py_deps("numpy"),
702+
)
703+
683704
jax_multiplatform_test(
684705
name = "tpu_fusable_matmul_test",
685706
srcs = ["tpu_fusable_matmul_test.py"],

0 commit comments

Comments
 (0)