1414
1515"""Fuses a function."""
1616
17+ from collections .abc import Sequence
18+ import functools
1719from typing import Any
18-
1920import jax
2021from jax ._src import api_util
2122from jax ._src import core as jax_core
2223from jax ._src import linear_util as lu
2324from jax ._src import tree_util
2425from jax ._src .interpreters import partial_eval as pe
25-
2626from jax ._src .pallas .fuser import fusable_dtype
2727from jax ._src .pallas .fuser import fusion as fusion_lib
2828from jax ._src .pallas .fuser .fusable import fusable_p
@@ -73,9 +73,9 @@ def wrapper(*args, **kwargs):
7373_fusable : dict [jax_core .Primitive , Any ] = {}
7474
7575
76- def construct_fusion (
76+ def _construct_fusion_jaxpr (
7777 candidate_values , jaxpr : jax_core .Jaxpr , outvars , * invars , ** kwargs
78- ) -> fusion_lib . Fusion :
78+ ):
7979 flat_outvars , out_tree = tree_util .tree_flatten (outvars )
8080 flat_invars , in_tree = tree_util .tree_flatten ((invars , kwargs ))
8181 new_jaxpr_no_dce = jaxpr .replace (
@@ -94,12 +94,6 @@ def construct_fusion(
9494 c for used , c in zip (used_consts , candidate_values , strict = True ) if used
9595 )
9696 kernel_in_tree = tree_util .tree_structure ((invars , kwargs ))
97-
98- def _fn (* args , ** kwargs ):
99- flat_args , _ = tree_util .tree_flatten ((args , kwargs ))
100- out_flat = jax_core .eval_jaxpr (new_jaxpr , new_values , * flat_args )
101- return tree_util .tree_unflatten (out_tree , out_flat )
102-
10397 flat_in_type = [
10498 jax .ShapeDtypeStruct (x .aval .shape , x .aval .dtype ) for x in flat_invars
10599 ]
@@ -108,9 +102,158 @@ def _fn(*args, **kwargs):
108102 out_tree ,
109103 [jax .ShapeDtypeStruct (x .aval .shape , x .aval .dtype ) for x in flat_outvars ],
110104 )
105+ return new_jaxpr , new_values , in_type , out_type , out_tree
106+
107+
108+ def construct_fusion (
109+ candidate_values , jaxpr : jax_core .Jaxpr , outvars , * invars , ** kwargs
110+ ) -> fusion_lib .Fusion :
111+ new_jaxpr , new_values , in_type , out_type , out_tree = _construct_fusion_jaxpr (
112+ candidate_values , jaxpr , outvars , * invars , ** kwargs
113+ )
114+
115+ def _fn (* args , ** kwargs ):
116+ flat_args , _ = tree_util .tree_flatten ((args , kwargs ))
117+ out_flat = jax_core .eval_jaxpr (new_jaxpr , new_values , * flat_args )
118+ return tree_util .tree_unflatten (out_tree , out_flat )
119+
111120 return fusion_lib .Fusion (_fn , in_type , out_type )
112121
113122
123+ def _find_downstream (
124+ jaxpr : jax_core .Jaxpr , in_used : Sequence [bool ]
125+ ) -> tuple [bool , ...]:
126+ # TODO(sharadmv): We use partial_eval to query downstream dependencies which
127+ # is not an officially sanctioned way to do so, since PE is really used for
128+ # AD. In the future, we should have a special Jaxpr API that queries this.
129+ _ , _ , out_used , * _ = pe .partial_eval_jaxpr_custom (
130+ jaxpr ,
131+ in_unknowns = in_used ,
132+ in_inst = in_used ,
133+ ensure_out_unknowns = False ,
134+ ensure_out_inst = False ,
135+ saveable = lambda * _ , ** __ : False ,
136+ )
137+ return tuple (out_used )
138+
139+
140+ def _construct_output_permutation (
141+ used : list [tuple [bool , ...]],
142+ ) -> list [int ]:
143+ order = []
144+ for u in used :
145+ true_vals = [i for i in range (len (u )) if u [i ]]
146+ order .extend (true_vals )
147+ return [order .index (i ) for i in range (len (order ))]
148+
149+
150+ def _construct_output_fusions (
151+ candidate_values ,
152+ jaxpr ,
153+ out_tree ,
154+ fusion_eqn_index ,
155+ fusion_eqn_outvars , # Flat list of vars output by the fusable eqn
156+ fusion_eqn_out_tree , # Tree structure of the fusable eqn outputs
157+ output_fusion_prefix , # Pytree defining output groups
158+ ):
159+ # 1. Create jaxpr_out: represents computation *after* the fusable
160+ # Inputs: fusion_eqn_outvars
161+ # Outputs: jaxpr.outvars
162+ jaxpr_out , all_values , _ , _ , _ = _construct_fusion_jaxpr (
163+ candidate_values ,
164+ jaxpr .replace (
165+ eqns = jaxpr .eqns [:fusion_eqn_index ]
166+ + jaxpr .eqns [fusion_eqn_index + 1 :]
167+ ),
168+ tree_util .tree_unflatten (out_tree , jaxpr .outvars ), # Original outputs
169+ tree_util .tree_unflatten (
170+ fusion_eqn_out_tree , fusion_eqn_outvars
171+ ), # Fusable outputs as inputs
172+ )
173+
174+ # 2. Group fusable outputs based on the mask
175+ unflat_fusable_outvars = jax .tree .unflatten (
176+ fusion_eqn_out_tree , fusion_eqn_outvars
177+ )
178+ partial_flat = jax .tree .structure (output_fusion_prefix ).flatten_up_to (
179+ unflat_fusable_outvars
180+ )
181+
182+ # 3. Calculate dependencies and check disjointness
183+ downstream_outputs_used_masks = [] # List of bool tuples, one per group
184+ already_used_final_outputs = set () # Indices of final outputs already claimed
185+ for outvars_group in partial_flat :
186+ # Identify vars in this group
187+ used_fusable_outvars = set (jax .tree .leaves (outvars_group ))
188+ # Create mask for jaxpr_out inputs corresponding to this group
189+ in_used_mask = [
190+ True if v in used_fusable_outvars else False for v in jaxpr_out .invars
191+ ]
192+ # Trace dependencies through jaxpr_out to find which final outputs are affected
193+ downstream_used_mask = _find_downstream (
194+ jaxpr_out , in_used_mask
195+ ) # Mask for jaxpr_out.outvars (== jaxpr.outvars)
196+
197+ # Check for overlap in final output usage across groups
198+ for i , used in enumerate (downstream_used_mask ):
199+ if used :
200+ if i in already_used_final_outputs :
201+ raise ValueError (
202+ "Outputs must be disjoint in order to use separate output fusions"
203+ )
204+ already_used_final_outputs .add (i )
205+ downstream_outputs_used_masks .append (downstream_used_mask )
206+
207+ # 4. Construct output permutation needed to restore original output order
208+ output_permutation = _construct_output_permutation (
209+ downstream_outputs_used_masks
210+ )
211+
212+ # Construct fusions for each group by DCEing the jaxpr_out
213+ output_fusions = []
214+ for i , outvars_group in enumerate (partial_flat ):
215+ flat_group_vars , _ = tree_util .tree_flatten (outvars_group )
216+ downstream_used_mask = downstream_outputs_used_masks [i ]
217+
218+ used_jaxpr_invars = [False ] * len (all_values ) + [
219+ v in flat_group_vars for v in jaxpr_out .invars
220+ ]
221+ jaxpr_out_for_group , used_consts , _ = pe .dce_jaxpr_consts (
222+ jaxpr_out , downstream_used_mask , instantiate = used_jaxpr_invars
223+ )
224+ values_for_jaxpr = tuple (
225+ c for used , c in zip (used_consts , all_values , strict = True ) if used
226+ )
227+
228+ def _fn (jaxpr , vals , * args , ** kwargs ):
229+ flat_args , _ = tree_util .tree_flatten ((args , kwargs ))
230+ out_flat = jax_core .eval_jaxpr (jaxpr , vals , * flat_args )
231+ return tuple (out_flat )
232+
233+ fn = functools .partial (_fn , jaxpr_out_for_group , values_for_jaxpr )
234+ in_type = jax .tree .map (
235+ lambda v : jax .ShapeDtypeStruct (v .aval .shape , v .aval .dtype ), # pytype: disable=attribute-error
236+ outvars_group ,
237+ )
238+ out_type = tuple (
239+ jax .ShapeDtypeStruct (v .aval .shape , v .aval .dtype ) # pytype: disable=attribute-error
240+ for v in jaxpr_out_for_group .outvars
241+ )
242+ fusion = fusion_lib .Fusion (
243+ fn ,
244+ (in_type , {}),
245+ out_type ,
246+ )
247+ output_fusions .append (fusion )
248+
249+ return (
250+ tree_util .tree_unflatten (
251+ tree_util .tree_structure (output_fusion_prefix ), output_fusions
252+ ),
253+ output_permutation ,
254+ )
255+
256+
114257def fuse_jaxpr (
115258 jaxpr : jax_core .Jaxpr , out_tree : tree_util .PyTreeDef , consts , * args
116259):
@@ -125,6 +268,15 @@ def fuse_jaxpr(
125268 raise ValueError ("No fusable eqn found" )
126269 fusion_eqn = jaxpr .eqns [fusion_eqn_index ]
127270
271+ # Now let's check if we need to do any fusion at all, e.g. do the outputs of
272+ # the jaxpr have any dependence on the fusion at all? We can DCE the jaxpr
273+ # with all the inputs and outputs to check if there is a dependence.
274+ dced_jaxpr , _ = pe .dce_jaxpr (jaxpr , [True ] * len (jaxpr .outvars ),
275+ instantiate = True )
276+ if not any (eqn .primitive is fusable_p for eqn in dced_jaxpr .eqns ):
277+ # Short circuit if there is nothing to fuse.
278+ return jax_core .eval_jaxpr (dced_jaxpr , consts , * args )
279+
128280 candidate_values = [* consts , * args ]
129281
130282 # Construct fusions for non-constant inputs to the fusable.
@@ -141,21 +293,20 @@ def fuse_jaxpr(
141293 in_fusions = tree_util .tree_unflatten (
142294 fusion_eqn .params ["in_tree" ], in_fusions_flat
143295 )
144- out_fusion = construct_fusion (
296+ output_fusions , output_permutation = _construct_output_fusions (
145297 candidate_values ,
146- jaxpr .replace (
147- eqns = jaxpr .eqns [:fusion_eqn_index ]
148- + jaxpr .eqns [fusion_eqn_index + 1 :]
149- ),
150- tree_util .tree_unflatten (out_tree , jaxpr .outvars ),
151- tree_util .tree_unflatten (
152- fusion_eqn .params ["out_tree" ], fusion_eqn .outvars
153- ),
298+ jaxpr ,
299+ out_tree ,
300+ fusion_eqn_index ,
301+ fusion_eqn .outvars ,
302+ fusion_eqn .params ["out_tree" ],
303+ fusion_eqn .params ["output_fusion_prefix" ],
154304 )
155- # Run the fusable.
156- out = fusion_eqn .params ["func" ](* in_fusions , out_fusion )
157-
158- # Now return the flattened output (the fuse_jaxpr caller should unflatten).
159- out_flat = tree_util .tree_leaves (out )
160- assert len (out_flat ) == len (jaxpr .outvars )
161- return out_flat
305+ out = fusion_eqn .params ["func" ](* in_fusions , output_fusions )
306+ flat_out = jax .tree .leaves (out )
307+ permuted_out = [flat_out [i ] for i in output_permutation ]
308+ assert len (permuted_out ) == len (jaxpr .outvars ), (
309+ len (permuted_out ),
310+ len (jaxpr .outvars ),
311+ )
312+ return permuted_out
0 commit comments