Skip to content

Commit 4d71575

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
Make sure to DCE read effects
PiperOrigin-RevId: 738215055
1 parent 8c7a55e commit 4d71575

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

jax/_src/interpreters/partial_eval.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
JaxprEqn, Primitive, ShapedArray, DShapedArray,
4242
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
4343
InputType, OutputType, get_referent, JaxprEqnContext)
44-
from jax._src.state.types import AbstractRef
44+
from jax._src.state.types import AbstractRef, ReadEffect
4545
from jax._src.tree_util import (PyTreeDef, treedef_tuple,
4646
tree_flatten, tree_structure)
4747
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@@ -1423,7 +1423,8 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
14231423

14241424

14251425
def has_effects(eqn: JaxprEqn) -> bool:
1426-
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
1426+
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)
1427+
and not isinstance(e, ReadEffect)}
14271428
return bool(effs)
14281429

14291430

0 commit comments

Comments
 (0)