diff --git a/jax/_src/core.py b/jax/_src/core.py index 2dded18dacbc..fd108f372906 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -3462,19 +3462,7 @@ def write(v: Var, a: AvalQDD) -> None: f"Equation effects: {eqn.effects}. " f"Inferred effects: {eqn_effects}") for eff in eqn.effects: - if isinstance(eff, effects.JaxprInputEffect): - eqn_invar = eqn.invars[eff.input_index] - if type(eqn_invar) is Literal or eqn_invar in mut_arrays: - continue - if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel: - raise JaxprTypeError( - "Invalid `JaxprInputEffect`: must correspond to a jaxpr invar") - jaxpr_effect = eff.replace(input_index=jaxpr_index) - if jaxpr_effect not in jaxpr.effects: - raise JaxprTypeError( - "Invalid `JaxprInputEffect`: must be present in jaxpr. " - f"{jaxpr_effect} is not in {jaxpr.effects}.") - elif isinstance(eff, NamedAxisEffect): + if isinstance(eff, NamedAxisEffect): # It is valid for a primitive to discharge the named axis effect. continue elif eff not in jaxpr.effects: @@ -3604,13 +3592,7 @@ def substitute(aval: AbstractValue): if type(d) is Var else d for d in a.shape)) if type(a) is DShapedArray else a for a in out_avals] - # jaxpr input effects are indexed to include jaxpr.constvars, but the eqn - # should have effects indexed only on its explicit arguments - effs = {e.replace(input_index=e.input_index - len(call_jaxpr.constvars)) - if isinstance(e, effects.JaxprInputEffect) - else e for e in call_jaxpr.effects} - - return out_type, effs + return out_type, call_jaxpr.effects def _check_map(ctx_factory, prim, in_avals, params): if "call_jaxpr" not in params: diff --git a/jax/_src/effects.py b/jax/_src/effects.py index efbf10638cf0..b840c934af7a 100644 --- a/jax/_src/effects.py +++ b/jax/_src/effects.py @@ -53,7 +53,6 @@ from __future__ import annotations from collections.abc import Iterable, Set -from typing import Any class Effect: @@ -66,29 +65,18 @@ class JaxprInputEffect(Effect): This is used as a base class for effects associated with inputs, e.g., reading/writing from mutable inputs. - - When used in a `JaxprEqn`, `input_index` refers to `eqn.invars`. - When used in a `Jaxpr`, `input_index` refers to `jaxpr.constvars + jaxpr.invars`. """ - def __init__(self, input_index: Any): - self.input_index = input_index - - def replace(self, *, input_index: Any | None = None): - if input_index is None: - input_index = self.input_index - return self.__class__(input_index) - def __eq__(self, other): if not isinstance(other, JaxprInputEffect): return NotImplemented - return self.input_index == other.input_index + return True def __hash__(self): - return hash((self.__class__, self.input_index)) + return hash(self.__class__) def __repr__(self): - return f"{self.__class__.__name__}({self.input_index})" + return f"{self.__class__.__name__}()" class EffectTypeSet: diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 7e5aacba3d80..9a731f3449ba 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -1473,11 +1473,8 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_ new_debug_info = jaxpr.jaxpr.debug_info._replace( arg_names=new_arg_names, result_paths=new_result_paths) constvars = jaxpr.jaxpr.constvars - new_effects = pe._renumber_effects( - (*constvars, *new_invars), (*constvars, *jaxpr.jaxpr.invars), - jaxpr.jaxpr.effects) new_jaxpr = core.Jaxpr(constvars, new_invars, new_outvars, jaxpr.jaxpr.eqns, - new_effects, new_debug_info) + jaxpr.jaxpr.effects, new_debug_info) return core.ClosedJaxpr(new_jaxpr, jaxpr.consts) def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int], diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 0101c45a3331..5c6c1515e820 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1601,11 +1601,7 @@ def _move_invars_right(jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]): invars, rest = split_list(jaxpr.jaxpr.invars, [len(to_move)]) left_invars, right_invars = partition_list(to_move, invars) new_invars = [*left_invars, *right_invars, *rest] - new_effs = _renumber_effects( - (*jaxpr.jaxpr.constvars, *new_invars), - (*jaxpr.jaxpr.constvars, *jaxpr.jaxpr.invars), - jaxpr.jaxpr.effects) - new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars, effects=new_effs) + new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars) return jaxpr.replace(jaxpr=new_jaxpr) def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] @@ -1619,23 +1615,15 @@ def _move_binders_to_front(jaxpr: ClosedJaxpr, to_move: tuple[bool, ...] assert len(jaxpr.in_avals) == len(to_move) constvars, invars = jaxpr.jaxpr.constvars, jaxpr.jaxpr.invars new_invars = _move_to_front(invars, to_move) - new_effs = _renumber_effects( - (*constvars, *new_invars), (*constvars, *invars), jaxpr.jaxpr.effects) if jaxpr.jaxpr.debug_info.arg_names is None: new_arg_names = None else: new_arg_names = tuple(_move_to_front(jaxpr.jaxpr.debug_info.arg_names, to_move)) dbg = jaxpr.jaxpr.debug_info._replace(arg_names=new_arg_names) new_jaxpr = jaxpr.jaxpr.replace( - constvars=constvars, invars=new_invars, effects=new_effs, debug_info=dbg) + constvars=constvars, invars=new_invars, debug_info=dbg) return core.ClosedJaxpr(new_jaxpr, jaxpr.consts) -def _renumber_effects(new_vars, old_vars, effs): - newvar_idxs = {id(v): i for i, v in enumerate(new_vars)} - old_to_new = {i: newvar_idxs[id(v)] for i, v in enumerate(old_vars)} - return {e.replace(input_index=old_to_new[e.input_index]) - if isinstance(e, effects.JaxprInputEffect) else e for e in effs} - def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence: return ([elt for elt, move in zip(lst, to_move) if move] + [elt for elt, move in zip(lst, to_move) if not move]) @@ -1763,35 +1751,6 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: all_vars[outvar] = None # type: ignore mut_arrays.add(outvar) for eff in eqn.effects: - if isinstance(eff, effects.JaxprInputEffect): - if eff.input_index >= len(eqn.invars): - # TODO(mattjj): ask for forgiveness - dbg = type('Fake', (), {'resolve_result_paths': lambda self_: self_, - 'assert_arg_names': lambda _, __: None, - 'assert_result_paths': lambda _, __: None, - })() - raise ValueError( - f"`JaxprInputEffect` {eff} is invalid." - f"\n Equation: {eqn}\n" - "\n Jaxpr: " - f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore - eqn_invar = eqn.invars[eff.input_index] - if type(eqn_invar) is core.Literal or eqn_invar in mut_arrays: - continue - if (input_index := all_vars.get(eqn_invar, sentinel)) is sentinel: - # TODO(mattjj): ask for forgiveness - dbg = type('Fake', (), {'resolve_result_paths': lambda self_: self_, - 'assert_arg_names': lambda _, __: None, - 'assert_result_paths': lambda _, __: None, - })() - raise ValueError( - f"`JaxprInputEffect` {eff} does not have " - f"corresponding jaxpr input: {eqn_invar=}." - f"\n Equation: {eqn}\n" - f"\n Effects: {eqn.effects}\n" - "\n Jaxpr: " - f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore - eff = eff.replace(input_index=input_index) jaxpr_effects.add(eff) return jaxpr_effects diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index 173297e4d8a1..99d8a398efd0 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -96,9 +96,7 @@ def _pad_constvars(jaxpr: core.Jaxpr, left: tuple[core.AvalQDD, ...], def make_var(aq): return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd) constvars = [*map(make_var, left), *jaxpr.constvars, *map(make_var, right)] - effs = pe._renumber_effects([*constvars, *jaxpr.invars], - [*jaxpr.constvars, *jaxpr.invars], jaxpr.effects) - jaxpr = jaxpr.replace(constvars=constvars, effects=effs) + jaxpr = jaxpr.replace(constvars=constvars) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr @@ -112,11 +110,7 @@ def _dedup_consts(jaxpr, const_ids): outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x for x in jaxpr.outvars] constvars = list(newvars.values()) - effs = pe._renumber_effects( - [*constvars, *jaxpr.invars], - [*map(canonicalize.get, jaxpr.constvars), *jaxpr.invars], jaxpr.effects) - jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars, - effects=effs) + jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 808545622608..0b9dfdce2d54 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -445,9 +445,6 @@ def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects: joined_effects = set() for b in branches: for eff in b.effects: - if isinstance(eff, effects.JaxprInputEffect): - # Offset index to handle predicate - eff = eff.replace(input_index=eff.input_index + 1) joined_effects.add(eff) return joined_effects diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 84a6b2d0bbfe..69a5781488c8 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -919,13 +919,7 @@ def _rearrange_mutable_binders( immut_names, mut_names = partition_list(is_mutable, names) new_arg_names = [*fst, *mut_names, *immut_names, *rst] dbg = jaxpr.jaxpr.debug_info._replace(arg_names=new_arg_names) - - # TODO(mattjj): don't we need to re-number effects? test coverage? - new_effs = pe._renumber_effects((*jaxpr.jaxpr.constvars, *new_invars), - (*jaxpr.jaxpr.constvars, *jaxpr.jaxpr.invars), - jaxpr.jaxpr.effects) - new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars, effects=new_effs, - debug_info=dbg) + new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars, debug_info=dbg) if config.enable_checks.value: core.check_jaxpr(new_jaxpr) return ClosedJaxpr(new_jaxpr, jaxpr.consts) @@ -1777,16 +1771,8 @@ def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts ) -> effects.Effects: joined_effects = set() for eff in cond_jaxpr.effects: - if isinstance(eff, effects.JaxprInputEffect): - index = eff.input_index - if index >= cond_nconsts: - index += body_nconsts - eff = eff.replace(input_index=index) joined_effects.add(eff) for eff in body_jaxpr.effects: - if isinstance(eff, effects.JaxprInputEffect): - index = eff.input_index + cond_nconsts - eff = eff.replace(input_index=index) joined_effects.add(eff) return joined_effects @@ -2285,17 +2271,6 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args, [cond_nconsts, body_nconsts]) - # Check if the same Ref is written to in both cond and body. - cond_write_ids = {id(cond_consts_avals[effect.input_index]) - for effect in cond_jaxpr.effects if isinstance(effect, state.WriteEffect)} - cond_has_writes = len(cond_write_ids) > 0 - body_write_ids = {id(body_consts_avals[effect.input_index]) - for effect in body_jaxpr.effects if isinstance(effect, state.WriteEffect)} - write_to_both_ids = cond_write_ids & body_write_ids - if write_to_both_ids: - raise NotImplementedError( - "Cannot write to the same ref in both cond and body of while loop.") - cond_is_ref = [ isinstance(aval, state.AbstractRef) and should for aval, should in zip(cond_consts_avals, cond_consts_discharge) @@ -2315,13 +2290,6 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args, num_body_refs = sum(body_is_ref) num_remaining_body_consts = body_nconsts - num_body_refs num_out_body_consts = num_remaining_body_consts - if cond_has_writes: - # If the cond has writes, we need to add the cond consts into the body - # consts since we need to evaluate the cond condition in the body. - remaining_body_consts = [*remaining_cond_consts, *remaining_body_consts] - remaining_body_const_avals = [*remaining_cond_const_avals, - *remaining_body_const_avals] - num_remaining_body_consts += num_remaining_cond_consts num_carry = len(in_avals) - body_nconsts - cond_nconsts body_jaxpr, body_jaxpr_consts = body_jaxpr.jaxpr, body_jaxpr.consts @@ -2353,23 +2321,8 @@ def new_body(*consts_refs_carry): consts, body_refs, cond_refs, carry = split_list( consts_refs_carry, [num_remaining_body_consts, num_body_refs, num_cond_refs]) - if cond_has_writes: - # We run the cond jaxpr in the body so that Refs that are updated - # in the cond jaxpr are persisted via the carry. - cond_consts, body_consts = split_list(consts, [num_remaining_cond_consts]) - cond_consts_and_refs = merge_lists(cond_is_ref, cond_consts, cond_refs) - cond_carry_refs = core.eval_jaxpr(discharged_cond_jaxpr, (), - *cond_consts_and_refs, - *carry) - # Note: in order to handle the same Ref being updated in both the cond - # and body, we would need to interleave the updated cond_carry_refs into - # body_refs here. - # Currently we disallow this so we don't need to handle it. - _, cond_refs_out = split_list(cond_carry_refs, [1]) - assert len(cond_refs_out) == len(cond_refs) - else: - body_consts = consts - cond_refs_out = cond_refs + body_consts = consts + cond_refs_out = cond_refs body_consts_and_refs = merge_lists(body_is_ref, body_consts, body_refs) body_carry_refs = core.eval_jaxpr(discharged_body_jaxpr, (), diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 62b00483fb1e..79e60593deac 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1833,18 +1833,9 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params): def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_): - effs = _pjit_eqn_effects(jaxpr) if jaxpr.constvars else jaxpr.effects - return jaxpr.out_avals, effs + return jaxpr.out_avals, jaxpr.effects jit_p.def_effectful_abstract_eval(_pjit_abstract_eval) -def _pjit_eqn_effects(jaxpr): - # jaxpr input effects are indexed to include jaxpr.constvars, but the pjit eqn - # should have effects indexed only on its explicit arguments - effs = jaxpr.effects - return {e.replace(input_index=e.input_index - len(jaxpr.constvars)) - if isinstance(e, effects.JaxprInputEffect) else e for e in effs} - - def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext, name: str, jaxpr: core.ClosedJaxpr, num_const_args: int, in_avals, @@ -2512,7 +2503,7 @@ def keep_where(xs, keeps): if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects: return used_inputs, None else: - new_effs = _pjit_eqn_effects(dced_jaxpr) + new_effs = dced_jaxpr.effects new_eqn = core.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], diff --git a/jax/_src/state/__init__.py b/jax/_src/state/__init__.py index adf7926d7dbd..7cc4260a8b11 100644 --- a/jax/_src/state/__init__.py +++ b/jax/_src/state/__init__.py @@ -21,7 +21,6 @@ Transform as Transform, TransformedRef as TransformedRef, WriteEffect as WriteEffect, - get_ref_state_effects as get_ref_state_effects, get_transforms_shape as get_transforms_shape, shaped_array_ref as shaped_array_ref, ) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index a20d5648217b..17bf7a47408f 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -685,39 +685,10 @@ def _run_state_abstract_eval(*avals: core.AbstractValue, jaxpr: core.Jaxpr, is_initialized: tuple[bool, ...]): del which_linear assert sum(is_initialized) == len(avals) - # When we abstractly evaluate `run_state`, we want to keep track of which - # input avals are `Ref`s and which are not. If an aval is a `Ref`, we want to - # "propagate" out its inner effects. Otherwise, the effects are local to this - # `run_state`. - inner_to_outer_aval_mapping = {} - outer_ref_index = 0 - for i, is_init in enumerate(is_initialized): - if not is_init: - pass - inner_to_outer_aval_mapping[i] = outer_ref_index - outer_ref_index += 1 - nonlocal_effects = set() - is_ref = {i for i, aval in enumerate(avals) if isinstance(aval, AbstractRef)} - for eff in jaxpr.effects: - if not isinstance(eff, RefEffect): - nonlocal_effects.add(eff) - continue - if eff.input_index not in inner_to_outer_aval_mapping: - # This means that this effect corresponds to an uninitialized Ref and - # should not propagate out of the primitive. - continue - # If we do propagate the effect, we need to update the input index to - # correspond to the outer index. - outer_index = inner_to_outer_aval_mapping[eff.input_index] - if outer_index in is_ref: - # This means that the effect corresponds to a Ref from an outside scope. - nonlocal_effects.add( - eff.replace(input_index=inner_to_outer_aval_mapping[eff.input_index]) - ) assert len(jaxpr.invars) == len(is_initialized) if not all(is_initialized): raise NotImplementedError # Uninitialized refs are not in avals. - return avals, nonlocal_effects + return avals, jaxpr.effects run_state_p.def_effectful_abstract_eval(_run_state_abstract_eval) def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *, diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 493ddc962b3b..dfca0f3cc01d 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -434,7 +434,7 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args, if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") out_aval = ref_aval.inner_aval - return (out_aval, {ReadEffect(0)}) + return (out_aval, {ReadEffect()}) get_p.def_effectful_abstract_eval(_get_abstract_eval) def _swap_abstract_eval(ref_aval: AbstractRef, @@ -471,7 +471,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef, if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") out_aval = ref_aval.inner_aval - return (out_aval, {WriteEffect(0)}) + return (out_aval, {WriteEffect()}) swap_p.def_effectful_abstract_eval(_swap_abstract_eval) @@ -508,7 +508,7 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef, # Check that the transforms are valid if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") - return [], {AccumEffect(0)} + return [], {AccumEffect()} addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval) ## Pretty printing for `get` and `swap` in jaxprs diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 2644f8392416..8332db347b6d 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -45,12 +45,10 @@ class RefEffect(effects.JaxprInputEffect): name: str def __eq__(self, other): - if not isinstance(other, self.__class__): - return False - return self.input_index == other.input_index + return isinstance(other, self.__class__) def __hash__(self): - return hash((self.__class__, self.input_index)) + return hash(self.__class__) def _pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: if isinstance(self.input_index, core.Var): @@ -572,13 +570,6 @@ def _unmap_ref(size, axis, explicit_mesh_axis, ref_aval): core.aval_mapping_handlers[AbstractRef] = (_map_ref, _unmap_ref) -def get_ref_state_effects( - avals: Sequence[core.AbstractValue], - effects: core.Effects) -> list[set[StateEffect]]: - return [{eff for eff in effects - if isinstance(eff, (ReadEffect, WriteEffect, AccumEffect)) - and eff.input_index == i} for i, _ in enumerate(avals)] - def shaped_array_ref( shape: tuple[int, ...], dtype, weak_type: bool = False) -> AbstractRef: return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type)) diff --git a/tests/state_test.py b/tests/state_test.py index 8441abb21e73..b6a66a65ed43 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -128,8 +128,7 @@ def f(x_ref): else: jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic( wrap_init(f, 1), [ref_aval]) - self.assertSetEqual(jaxpr.effects, - {ReadEffect(len(jaxpr.constvars))}) + self.assertSetEqual(jaxpr.effects, {ReadEffect()}) self.assertLen(out_avals, 1) out_aval, = out_avals self.assertIsInstance(out_aval, core.ShapedArray) @@ -225,8 +224,7 @@ def f(x_ref, val): else: jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic( wrap_init(f, 2), [ref_aval, val_aval]) - self.assertSetEqual(jaxpr.effects, - {WriteEffect(len(jaxpr.constvars))}) + self.assertSetEqual(jaxpr.effects, {WriteEffect()}) self.assertLen(out_avals, 1) out_aval, = out_avals self.assertIsInstance(out_aval, core.ShapedArray) @@ -302,8 +300,7 @@ def f(x_ref, val): else: jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic( wrap_init(f, 2), [ref_aval, val_aval]) - self.assertSetEqual(jaxpr.effects, - {AccumEffect(len(jaxpr.constvars))}) + self.assertSetEqual(jaxpr.effects, {AccumEffect()}) self.assertLen(out_avals, 0) def test_addupdate_abstract_eval_must_take_in_refs(self): @@ -738,8 +735,7 @@ def f(a_ref, b_ref): self.assertLen(discharged_jaxpr.outvars, 1) self.assertIsInstance(discharged_jaxpr.invars[0].aval, AbstractRef) self.assertIsInstance(discharged_jaxpr.invars[1].aval, core.ShapedArray) - self.assertEqual(discharged_jaxpr.effects, - {WriteEffect(len(discharged_jaxpr.constvars))}) + self.assertEqual(discharged_jaxpr.effects, {WriteEffect()}) def test_ellipsis_index(self): def f(ref): @@ -781,8 +777,8 @@ def body(i, st): f_jaxpr = jax.make_jaxpr(f)(ref(1.), ref(2.)) jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, True]) # Effects on y_ref were discharged away but not the effects on x_ref - self.assertEqual(f_jaxpr.effects, {ReadEffect(0), WriteEffect(0), ReadEffect(1), WriteEffect(1)}) - self.assertEqual(jaxpr.effects, {ReadEffect(0), WriteEffect(0)}) + self.assertEqual(f_jaxpr.effects, {ReadEffect(), WriteEffect(), ReadEffect(), WriteEffect()}) + self.assertEqual(jaxpr.effects, {ReadEffect(), WriteEffect()}) # x_ref arg is still a reference but y_ref is discharged self.assertNotIsInstance(jaxpr.invars[1].aval, AbstractRef) self.assertIsInstance(jaxpr.invars[0].aval, AbstractRef) @@ -1104,8 +1100,8 @@ def false_fun(): f_jaxpr = jax.make_jaxpr(f0)(False, ref(3.), ref(4.)) jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, False, True]) # Effects on y_ref were discharged away but not the effects on x_ref - self.assertEqual(f_jaxpr.effects, {ReadEffect(1), WriteEffect(1), ReadEffect(2), WriteEffect(2)}) - self.assertEqual(jaxpr.effects, {ReadEffect(1), WriteEffect(1)}) + self.assertEqual(f_jaxpr.effects, {ReadEffect(), WriteEffect(), ReadEffect(), WriteEffect()}) + self.assertEqual(jaxpr.effects, {ReadEffect(), WriteEffect()}) # x_ref arg is still a reference but y_ref is discharged self.assertNotIsInstance(jaxpr.invars[2].aval, AbstractRef) self.assertIsInstance(jaxpr.invars[1].aval, AbstractRef) @@ -1493,10 +1489,10 @@ def inner(y_ref): self.assertEmpty(jaxpr.effects) self.assertEmpty(jaxpr.jaxpr.eqns[0].effects) self.assertSetEqual(jaxpr.jaxpr.eqns[0].params["jaxpr"].effects, - {ReadEffect(0)}) + {ReadEffect()}) self.assertSetEqual( jaxpr.jaxpr.eqns[0].params["jaxpr"].eqns[0].params["jaxpr"].effects, - {ReadEffect(0), ReadEffect(1)}) + {ReadEffect(), ReadEffect()}) def test_jvp_of_run_state(self): @run_state