Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,17 +1813,17 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):


def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_):
# jaxpr input effects are indexed to include jaxpr.constvars, but the pjit eqn
# should have effects indexed only on its explicit arguments
if jaxpr.constvars:
effs = {e.replace(input_index=e.input_index - len(jaxpr.constvars))
if isinstance(e, effects.JaxprInputEffect)
else e for e in jaxpr.effects}
else:
effs = jaxpr.effects
effs = _pjit_eqn_effects(jaxpr) if jaxpr.constvars else jaxpr.effects
return jaxpr.out_avals, effs
jit_p.def_effectful_abstract_eval(_pjit_abstract_eval)

def _pjit_eqn_effects(jaxpr):
# jaxpr input effects are indexed to include jaxpr.constvars, but the pjit eqn
# should have effects indexed only on its explicit arguments
effs = jaxpr.effects
return {e.replace(input_index=e.input_index - len(jaxpr.constvars))
if isinstance(e, effects.JaxprInputEffect) else e for e in effs}


def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
name: str, jaxpr: core.ClosedJaxpr,
Expand Down Expand Up @@ -2491,10 +2491,11 @@ def keep_where(xs, keeps):
if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects:
return used_inputs, None
else:
new_effs = _pjit_eqn_effects(dced_jaxpr)
new_eqn = core.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, dced_jaxpr.effects, eqn.source_info, eqn.ctx)
eqn.primitive, new_params, new_effs, eqn.source_info, eqn.ctx)
return used_inputs, new_eqn

pe.dce_rules[jit_p] = dce_jaxpr_pjit_rule
Expand Down
36 changes: 36 additions & 0 deletions tests/jaxpr_effects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,42 @@ def foo(w):

foo(jax.new_ref(jnp.eye(1))) # don't crash

def test_cond_const_input_effect_indexing(self):
@jax.custom_jvp
def weird(x):
return x

@weird.defjvp
def weird_jvp(primals, tangents):
(x,), (xdot,) = primals, tangents
return jnp.sum(np.ones(3)) * x, xdot

@jax.jit
def f(x):
x_ref = jax.new_ref(0.)
return jax.lax.cond(x < 0, lambda: x_ref[...], lambda: weird(x[...]))

jax.jvp(f, (1.,), (1.,))

def test_scan_const_input_effect_indexing(self):
@jax.custom_jvp
def weird(x):
return x

@weird.defjvp
def weird_jvp(primals, tangents):
(x,), (xdot,) = primals, tangents
return jnp.sum(np.ones(3)) * x, xdot

@jax.jit
def f(x):
x_ref = jax.new_ref(0.)
y, () = jax.lax.scan(lambda _, __: (weird(x_ref[...]), ()),
x_ref[...], length=1)
return y

jax.jvp(f, (1.,), (1.,))


@jtu.thread_unsafe_test_class() # because of mlir.register_lowering calls
class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
Expand Down
Loading