diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 16d301b74abd..a59c0cdeeeb3 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1813,17 +1813,17 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params): def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_): - # jaxpr input effects are indexed to include jaxpr.constvars, but the pjit eqn - # should have effects indexed only on its explicit arguments - if jaxpr.constvars: - effs = {e.replace(input_index=e.input_index - len(jaxpr.constvars)) - if isinstance(e, effects.JaxprInputEffect) - else e for e in jaxpr.effects} - else: - effs = jaxpr.effects + effs = _pjit_eqn_effects(jaxpr) if jaxpr.constvars else jaxpr.effects return jaxpr.out_avals, effs jit_p.def_effectful_abstract_eval(_pjit_abstract_eval) +def _pjit_eqn_effects(jaxpr): + # jaxpr input effects are indexed to include jaxpr.constvars, but the pjit eqn + # should have effects indexed only on its explicit arguments + effs = jaxpr.effects + return {e.replace(input_index=e.input_index - len(jaxpr.constvars)) + if isinstance(e, effects.JaxprInputEffect) else e for e in effs} + def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext, name: str, jaxpr: core.ClosedJaxpr, @@ -2491,10 +2491,11 @@ def keep_where(xs, keeps): if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects: return used_inputs, None else: + new_effs = _pjit_eqn_effects(dced_jaxpr) new_eqn = core.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, dced_jaxpr.effects, eqn.source_info, eqn.ctx) + eqn.primitive, new_params, new_effs, eqn.source_info, eqn.ctx) return used_inputs, new_eqn pe.dce_rules[jit_p] = dce_jaxpr_pjit_rule diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 997ba86c16bd..67120364ee28 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -276,6 +276,42 @@ def foo(w): foo(jax.new_ref(jnp.eye(1))) # don't crash + def test_cond_const_input_effect_indexing(self): + @jax.custom_jvp + def weird(x): + return x + + @weird.defjvp + def weird_jvp(primals, tangents): + (x,), (xdot,) = primals, tangents + return jnp.sum(np.ones(3)) * x, xdot + + @jax.jit + def f(x): + x_ref = jax.new_ref(0.) + return jax.lax.cond(x < 0, lambda: x_ref[...], lambda: weird(x[...])) + + jax.jvp(f, (1.,), (1.,)) + + def test_scan_const_input_effect_indexing(self): + @jax.custom_jvp + def weird(x): + return x + + @weird.defjvp + def weird_jvp(primals, tangents): + (x,), (xdot,) = primals, tangents + return jnp.sum(np.ones(3)) * x, xdot + + @jax.jit + def f(x): + x_ref = jax.new_ref(0.) + y, () = jax.lax.scan(lambda _, __: (weird(x_ref[...]), ()), + x_ref[...], length=1) + return y + + jax.jvp(f, (1.,), (1.,)) + @jtu.thread_unsafe_test_class() # because of mlir.register_lowering calls class EffectfulJaxprLoweringTest(jtu.JaxTestCase):