Skip to content

Commit e949eff

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
[Pallas/Fuser] DCE fusion jaxprs before pulling (to avoid unnecessary computations being staged out in block functions)
PiperOrigin-RevId: 738218113
1 parent 4d71575 commit e949eff

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

jax/_src/pallas/fuser/block_spec.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,13 @@ def wrapped(*args, **kwargs):
244244
_unwrap_block_spec_scalar_prefetch, out_block_specs
245245
)
246246
flat_block_specs, out_tree = jax.tree.flatten(block_specs_)
247+
jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
248+
jaxpr,
249+
used_outputs=[True] * len(jaxpr.outvars),
250+
instantiate=True,
251+
)
252+
assert all(used_invars)
253+
assert all(used_consts)
247254
in_block_specs, env, read_usage_env = _pull_block_spec(
248255
jaxpr,
249256
tuple(flat_block_specs),

0 commit comments

Comments
 (0)