2626import jax
2727from jax import lax
2828from jax ._src import ad_util
29- from jax ._src import api_util
3029from jax ._src import core
3130from jax ._src import custom_derivatives
32- from jax ._src import linear_util as lu
3331from jax ._src import pjit
3432from jax ._src import tree_util
3533from jax ._src import util
3634from jax ._src .interpreters import partial_eval as pe
3735from jax ._src .pallas import core as pallas_core
36+ from jax ._src .pallas .fuser import fuser_utils
3837import jax .numpy as jnp
3938import numpy as np
4039
@@ -226,18 +225,6 @@ def new_index_map(*args):
226225 return out_block_spec
227226
228227
229- def _make_jaxpr (f , * args , ** kwargs ):
230- flat_args , in_tree = tree_util .tree_flatten ((args , kwargs ))
231- flat_avals = [core .get_aval (x ) for x in flat_args ]
232- debug_info = api_util .debug_info ('make_jaxpr' , f , args , kwargs )
233- flat_fun , out_tree_thunk = api_util .flatten_fun (
234- lu .wrap_init (f , debug_info = debug_info ), in_tree
235- )
236- jaxpr , _ , consts , _ = pe .trace_to_jaxpr_dynamic (flat_fun , flat_avals )
237- out_tree = out_tree_thunk ()
238- return jaxpr , consts , in_tree , out_tree
239-
240-
241228def pull_block_spec (
242229 f : Callable ,
243230 out_block_specs : pallas_core .BlockSpec | tuple [pallas_core .BlockSpec , ...],
@@ -246,7 +233,9 @@ def pull_block_spec(
246233 grid : tuple [int | jax .Array , ...] | None = None ,
247234):
248235 def wrapped (* args , ** kwargs ):
249- jaxpr , consts , in_tree , out_tree_ = _make_jaxpr (f , * args , ** kwargs )
236+ jaxpr , consts , in_tree , out_tree_ = fuser_utils .make_jaxpr (
237+ f , * args , ** kwargs
238+ )
250239 # TODO(sharadmv): handle these consts better, they should correspond to
251240 # scalar prefetch.
252241 del consts , out_tree_
@@ -563,7 +552,9 @@ def write_env(var, val):
563552def get_fusion_values (
564553 fusion : Callable , * args , ** kwargs
565554) -> tuple [Callable , tuple [jax .Array , ...], tuple [jax .Array , ...]]:
566- jaxpr , values , in_tree , out_tree = _make_jaxpr (fusion , * args , ** kwargs )
555+ jaxpr , values , in_tree , out_tree = fuser_utils .make_jaxpr (
556+ fusion , * args , ** kwargs
557+ )
567558 assert len (values ) == len (jaxpr .constvars ), (jaxpr , values )
568559 out_usages = tuple ({Usage .REGULAR } for _ in jaxpr .outvars )
569560 read_usage_env = compute_usage (jaxpr , out_usages )
@@ -1325,7 +1316,7 @@ def wrapper(*args, **kwargs):
13251316 flat_block_specs , in_tree_ = tree_util .tree_flatten (
13261317 (in_spec_args , in_spec_kwargs )
13271318 )
1328- jaxpr , _ , in_tree , out_tree = _make_jaxpr (f , * args , ** kwargs )
1319+ jaxpr , _ , in_tree , out_tree = fuser_utils . make_jaxpr (f , * args , ** kwargs )
13291320 if in_tree != in_tree_ :
13301321 raise ValueError (f'Expected { in_tree } PyTree, got { in_tree_ } ' )
13311322 out_bs = _push_block_spec_jaxpr (jaxpr , * flat_block_specs )
0 commit comments