Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3462,19 +3462,7 @@ def write(v: Var, a: AvalQDD) -> None:
f"Equation effects: {eqn.effects}. "
f"Inferred effects: {eqn_effects}")
for eff in eqn.effects:
if isinstance(eff, effects.JaxprInputEffect):
eqn_invar = eqn.invars[eff.input_index]
if type(eqn_invar) is Literal or eqn_invar in mut_arrays:
continue
if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel:
raise JaxprTypeError(
"Invalid `JaxprInputEffect`: must correspond to a jaxpr invar")
jaxpr_effect = eff.replace(input_index=jaxpr_index)
if jaxpr_effect not in jaxpr.effects:
raise JaxprTypeError(
"Invalid `JaxprInputEffect`: must be present in jaxpr. "
f"{jaxpr_effect} is not in {jaxpr.effects}.")
elif isinstance(eff, NamedAxisEffect):
if isinstance(eff, NamedAxisEffect):
# It is valid for a primitive to discharge the named axis effect.
continue
elif eff not in jaxpr.effects:
Expand Down Expand Up @@ -3604,13 +3592,7 @@ def substitute(aval: AbstractValue):
if type(d) is Var else d for d in a.shape))
if type(a) is DShapedArray else a for a in out_avals]

# jaxpr input effects are indexed to include jaxpr.constvars, but the eqn
# should have effects indexed only on its explicit arguments
effs = {e.replace(input_index=e.input_index - len(call_jaxpr.constvars))
if isinstance(e, effects.JaxprInputEffect)
else e for e in call_jaxpr.effects}

return out_type, effs
return out_type, call_jaxpr.effects

def _check_map(ctx_factory, prim, in_avals, params):
if "call_jaxpr" not in params:
Expand Down
18 changes: 3 additions & 15 deletions jax/_src/effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from __future__ import annotations

from collections.abc import Iterable, Set
from typing import Any


class Effect:
Expand All @@ -66,29 +65,18 @@ class JaxprInputEffect(Effect):

This is used as a base class for effects associated with inputs, e.g.,
reading/writing from mutable inputs.

When used in a `JaxprEqn`, `input_index` refers to `eqn.invars`.
When used in a `Jaxpr`, `input_index` refers to `jaxpr.constvars + jaxpr.invars`.
"""

def __init__(self, input_index: Any):
self.input_index = input_index

def replace(self, *, input_index: Any | None = None):
if input_index is None:
input_index = self.input_index
return self.__class__(input_index)

def __eq__(self, other):
if not isinstance(other, JaxprInputEffect):
return NotImplemented
return self.input_index == other.input_index
return True

def __hash__(self):
return hash((self.__class__, self.input_index))
return hash(self.__class__)

def __repr__(self):
return f"{self.__class__.__name__}({self.input_index})"
return f"{self.__class__.__name__}()"

class EffectTypeSet:

Expand Down
5 changes: 1 addition & 4 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,11 +1473,8 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_
new_debug_info = jaxpr.jaxpr.debug_info._replace(
arg_names=new_arg_names, result_paths=new_result_paths)
constvars = jaxpr.jaxpr.constvars
new_effects = pe._renumber_effects(
(*constvars, *new_invars), (*constvars, *jaxpr.jaxpr.invars),
jaxpr.jaxpr.effects)
new_jaxpr = core.Jaxpr(constvars, new_invars, new_outvars, jaxpr.jaxpr.eqns,
new_effects, new_debug_info)
jaxpr.jaxpr.effects, new_debug_info)
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)

def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int],
Expand Down
45 changes: 2 additions & 43 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,11 +1601,7 @@ def _move_invars_right(jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]):
invars, rest = split_list(jaxpr.jaxpr.invars, [len(to_move)])
left_invars, right_invars = partition_list(to_move, invars)
new_invars = [*left_invars, *right_invars, *rest]
new_effs = _renumber_effects(
(*jaxpr.jaxpr.constvars, *new_invars),
(*jaxpr.jaxpr.constvars, *jaxpr.jaxpr.invars),
jaxpr.jaxpr.effects)
new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars, effects=new_effs)
new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars)
return jaxpr.replace(jaxpr=new_jaxpr)

def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
Expand All @@ -1619,23 +1615,15 @@ def _move_binders_to_front(jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]
assert len(jaxpr.in_avals) == len(to_move)
constvars, invars = jaxpr.jaxpr.constvars, jaxpr.jaxpr.invars
new_invars = _move_to_front(invars, to_move)
new_effs = _renumber_effects(
(*constvars, *new_invars), (*constvars, *invars), jaxpr.jaxpr.effects)
if jaxpr.jaxpr.debug_info.arg_names is None:
new_arg_names = None
else:
new_arg_names = tuple(_move_to_front(jaxpr.jaxpr.debug_info.arg_names, to_move))
dbg = jaxpr.jaxpr.debug_info._replace(arg_names=new_arg_names)
new_jaxpr = jaxpr.jaxpr.replace(
constvars=constvars, invars=new_invars, effects=new_effs, debug_info=dbg)
constvars=constvars, invars=new_invars, debug_info=dbg)
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)

def _renumber_effects(new_vars, old_vars, effs):
newvar_idxs = {id(v): i for i, v in enumerate(new_vars)}
old_to_new = {i: newvar_idxs[id(v)] for i, v in enumerate(old_vars)}
return {e.replace(input_index=old_to_new[e.input_index])
if isinstance(e, effects.JaxprInputEffect) else e for e in effs}

def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence:
return ([elt for elt, move in zip(lst, to_move) if move] +
[elt for elt, move in zip(lst, to_move) if not move])
Expand Down Expand Up @@ -1763,35 +1751,6 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
all_vars[outvar] = None # type: ignore
mut_arrays.add(outvar)
for eff in eqn.effects:
if isinstance(eff, effects.JaxprInputEffect):
if eff.input_index >= len(eqn.invars):
# TODO(mattjj): ask for forgiveness
dbg = type('Fake', (), {'resolve_result_paths': lambda self_: self_,
'assert_arg_names': lambda _, __: None,
'assert_result_paths': lambda _, __: None,
})()
raise ValueError(
f"`JaxprInputEffect` {eff} is invalid."
f"\n Equation: {eqn}\n"
"\n Jaxpr: "
f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore
eqn_invar = eqn.invars[eff.input_index]
if type(eqn_invar) is core.Literal or eqn_invar in mut_arrays:
continue
if (input_index := all_vars.get(eqn_invar, sentinel)) is sentinel:
# TODO(mattjj): ask for forgiveness
dbg = type('Fake', (), {'resolve_result_paths': lambda self_: self_,
'assert_arg_names': lambda _, __: None,
'assert_result_paths': lambda _, __: None,
})()
raise ValueError(
f"`JaxprInputEffect` {eff} does not have "
f"corresponding jaxpr input: {eqn_invar=}."
f"\n Equation: {eqn}\n"
f"\n Effects: {eqn.effects}\n"
"\n Jaxpr: "
f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore
eff = eff.replace(input_index=input_index)
jaxpr_effects.add(eff)
return jaxpr_effects

Expand Down
10 changes: 2 additions & 8 deletions jax/_src/lax/control_flow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,7 @@ def _pad_constvars(jaxpr: core.Jaxpr, left: tuple[core.AvalQDD, ...],
def make_var(aq):
return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd)
constvars = [*map(make_var, left), *jaxpr.constvars, *map(make_var, right)]
effs = pe._renumber_effects([*constvars, *jaxpr.invars],
[*jaxpr.constvars, *jaxpr.invars], jaxpr.effects)
jaxpr = jaxpr.replace(constvars=constvars, effects=effs)
jaxpr = jaxpr.replace(constvars=constvars)
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr

Expand All @@ -112,11 +110,7 @@ def _dedup_consts(jaxpr, const_ids):
outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x
for x in jaxpr.outvars]
constvars = list(newvars.values())
effs = pe._renumber_effects(
[*constvars, *jaxpr.invars],
[*map(canonicalize.get, jaxpr.constvars), *jaxpr.invars], jaxpr.effects)
jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars,
effects=effs)
jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars)
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr

Expand Down
3 changes: 0 additions & 3 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,6 @@ def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects:
joined_effects = set()
for b in branches:
for eff in b.effects:
if isinstance(eff, effects.JaxprInputEffect):
# Offset index to handle predicate
eff = eff.replace(input_index=eff.input_index + 1)
joined_effects.add(eff)
return joined_effects

Expand Down
53 changes: 3 additions & 50 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,13 +919,7 @@ def _rearrange_mutable_binders(
immut_names, mut_names = partition_list(is_mutable, names)
new_arg_names = [*fst, *mut_names, *immut_names, *rst]
dbg = jaxpr.jaxpr.debug_info._replace(arg_names=new_arg_names)

# TODO(mattjj): don't we need to re-number effects? test coverage?
new_effs = pe._renumber_effects((*jaxpr.jaxpr.constvars, *new_invars),
(*jaxpr.jaxpr.constvars, *jaxpr.jaxpr.invars),
jaxpr.jaxpr.effects)
new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars, effects=new_effs,
debug_info=dbg)
new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars, debug_info=dbg)
if config.enable_checks.value: core.check_jaxpr(new_jaxpr)
return ClosedJaxpr(new_jaxpr, jaxpr.consts)

Expand Down Expand Up @@ -1777,16 +1771,8 @@ def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts
) -> effects.Effects:
joined_effects = set()
for eff in cond_jaxpr.effects:
if isinstance(eff, effects.JaxprInputEffect):
index = eff.input_index
if index >= cond_nconsts:
index += body_nconsts
eff = eff.replace(input_index=index)
joined_effects.add(eff)
for eff in body_jaxpr.effects:
if isinstance(eff, effects.JaxprInputEffect):
index = eff.input_index + cond_nconsts
eff = eff.replace(input_index=index)
joined_effects.add(eff)
return joined_effects

Expand Down Expand Up @@ -2285,17 +2271,6 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args,
[cond_nconsts,
body_nconsts])

# Check if the same Ref is written to in both cond and body.
cond_write_ids = {id(cond_consts_avals[effect.input_index])
for effect in cond_jaxpr.effects if isinstance(effect, state.WriteEffect)}
cond_has_writes = len(cond_write_ids) > 0
body_write_ids = {id(body_consts_avals[effect.input_index])
for effect in body_jaxpr.effects if isinstance(effect, state.WriteEffect)}
write_to_both_ids = cond_write_ids & body_write_ids
if write_to_both_ids:
raise NotImplementedError(
"Cannot write to the same ref in both cond and body of while loop.")

cond_is_ref = [
isinstance(aval, state.AbstractRef) and should
for aval, should in zip(cond_consts_avals, cond_consts_discharge)
Expand All @@ -2315,13 +2290,6 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args,
num_body_refs = sum(body_is_ref)
num_remaining_body_consts = body_nconsts - num_body_refs
num_out_body_consts = num_remaining_body_consts
if cond_has_writes:
# If the cond has writes, we need to add the cond consts into the body
# consts since we need to evaluate the cond condition in the body.
remaining_body_consts = [*remaining_cond_consts, *remaining_body_consts]
remaining_body_const_avals = [*remaining_cond_const_avals,
*remaining_body_const_avals]
num_remaining_body_consts += num_remaining_cond_consts

num_carry = len(in_avals) - body_nconsts - cond_nconsts
body_jaxpr, body_jaxpr_consts = body_jaxpr.jaxpr, body_jaxpr.consts
Expand Down Expand Up @@ -2353,23 +2321,8 @@ def new_body(*consts_refs_carry):
consts, body_refs, cond_refs, carry = split_list(
consts_refs_carry,
[num_remaining_body_consts, num_body_refs, num_cond_refs])
if cond_has_writes:
# We run the cond jaxpr in the body so that Refs that are updated
# in the cond jaxpr are persisted via the carry.
cond_consts, body_consts = split_list(consts, [num_remaining_cond_consts])
cond_consts_and_refs = merge_lists(cond_is_ref, cond_consts, cond_refs)
cond_carry_refs = core.eval_jaxpr(discharged_cond_jaxpr, (),
*cond_consts_and_refs,
*carry)
# Note: in order to handle the same Ref being updated in both the cond
# and body, we would need to interleave the updated cond_carry_refs into
# body_refs here.
# Currently we disallow this so we don't need to handle it.
_, cond_refs_out = split_list(cond_carry_refs, [1])
assert len(cond_refs_out) == len(cond_refs)
else:
body_consts = consts
cond_refs_out = cond_refs
body_consts = consts
cond_refs_out = cond_refs

body_consts_and_refs = merge_lists(body_is_ref, body_consts, body_refs)
body_carry_refs = core.eval_jaxpr(discharged_body_jaxpr, (),
Expand Down
13 changes: 2 additions & 11 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1833,18 +1833,9 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):


def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_):
effs = _pjit_eqn_effects(jaxpr) if jaxpr.constvars else jaxpr.effects
return jaxpr.out_avals, effs
return jaxpr.out_avals, jaxpr.effects
jit_p.def_effectful_abstract_eval(_pjit_abstract_eval)

def _pjit_eqn_effects(jaxpr):
# jaxpr input effects are indexed to include jaxpr.constvars, but the pjit eqn
# should have effects indexed only on its explicit arguments
effs = jaxpr.effects
return {e.replace(input_index=e.input_index - len(jaxpr.constvars))
if isinstance(e, effects.JaxprInputEffect) else e for e in effs}


def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
name: str, jaxpr: core.ClosedJaxpr,
num_const_args: int, in_avals,
Expand Down Expand Up @@ -2512,7 +2503,7 @@ def keep_where(xs, keeps):
if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects:
return used_inputs, None
else:
new_effs = _pjit_eqn_effects(dced_jaxpr)
new_effs = dced_jaxpr.effects
new_eqn = core.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
Expand Down
1 change: 0 additions & 1 deletion jax/_src/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Transform as Transform,
TransformedRef as TransformedRef,
WriteEffect as WriteEffect,
get_ref_state_effects as get_ref_state_effects,
get_transforms_shape as get_transforms_shape,
shaped_array_ref as shaped_array_ref,
)
31 changes: 1 addition & 30 deletions jax/_src/state/discharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,39 +685,10 @@ def _run_state_abstract_eval(*avals: core.AbstractValue, jaxpr: core.Jaxpr,
is_initialized: tuple[bool, ...]):
del which_linear
assert sum(is_initialized) == len(avals)
# When we abstractly evaluate `run_state`, we want to keep track of which
# input avals are `Ref`s and which are not. If an aval is a `Ref`, we want to
# "propagate" out its inner effects. Otherwise, the effects are local to this
# `run_state`.
inner_to_outer_aval_mapping = {}
outer_ref_index = 0
for i, is_init in enumerate(is_initialized):
if not is_init:
pass
inner_to_outer_aval_mapping[i] = outer_ref_index
outer_ref_index += 1
nonlocal_effects = set()
is_ref = {i for i, aval in enumerate(avals) if isinstance(aval, AbstractRef)}
for eff in jaxpr.effects:
if not isinstance(eff, RefEffect):
nonlocal_effects.add(eff)
continue
if eff.input_index not in inner_to_outer_aval_mapping:
# This means that this effect corresponds to an uninitialized Ref and
# should not propagate out of the primitive.
continue
# If we do propagate the effect, we need to update the input index to
# correspond to the outer index.
outer_index = inner_to_outer_aval_mapping[eff.input_index]
if outer_index in is_ref:
# This means that the effect corresponds to a Ref from an outside scope.
nonlocal_effects.add(
eff.replace(input_index=inner_to_outer_aval_mapping[eff.input_index])
)
assert len(jaxpr.invars) == len(is_initialized)
if not all(is_initialized):
raise NotImplementedError # Uninitialized refs are not in avals.
return avals, nonlocal_effects
return avals, jaxpr.effects
run_state_p.def_effectful_abstract_eval(_run_state_abstract_eval)

def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *,
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/state/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args,
if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
out_aval = ref_aval.inner_aval
return (out_aval, {ReadEffect(0)})
return (out_aval, {ReadEffect()})
get_p.def_effectful_abstract_eval(_get_abstract_eval)

def _swap_abstract_eval(ref_aval: AbstractRef,
Expand Down Expand Up @@ -471,7 +471,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef,
if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
out_aval = ref_aval.inner_aval
return (out_aval, {WriteEffect(0)})
return (out_aval, {WriteEffect()})
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)


Expand Down Expand Up @@ -508,7 +508,7 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef,
# Check that the transforms are valid
if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
return [], {AccumEffect(0)}
return [], {AccumEffect()}
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)

## Pretty printing for `get` and `swap` in jaxprs
Expand Down
Loading
Loading