Skip to content

Commit 161b183

Browse files
committed
Seeing what happens if we get rid of positional effects
1 parent f3bb52b commit 161b183

File tree

13 files changed

+31
-213
lines changed

13 files changed

+31
-213
lines changed

jax/_src/core.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3462,19 +3462,7 @@ def write(v: Var, a: AvalQDD) -> None:
34623462
f"Equation effects: {eqn.effects}. "
34633463
f"Inferred effects: {eqn_effects}")
34643464
for eff in eqn.effects:
3465-
if isinstance(eff, effects.JaxprInputEffect):
3466-
eqn_invar = eqn.invars[eff.input_index]
3467-
if type(eqn_invar) is Literal or eqn_invar in mut_arrays:
3468-
continue
3469-
if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel:
3470-
raise JaxprTypeError(
3471-
"Invalid `JaxprInputEffect`: must correspond to a jaxpr invar")
3472-
jaxpr_effect = eff.replace(input_index=jaxpr_index)
3473-
if jaxpr_effect not in jaxpr.effects:
3474-
raise JaxprTypeError(
3475-
"Invalid `JaxprInputEffect`: must be present in jaxpr. "
3476-
f"{jaxpr_effect} is not in {jaxpr.effects}.")
3477-
elif isinstance(eff, NamedAxisEffect):
3465+
if isinstance(eff, NamedAxisEffect):
34783466
# It is valid for a primitive to discharge the named axis effect.
34793467
continue
34803468
elif eff not in jaxpr.effects:
@@ -3604,13 +3592,7 @@ def substitute(aval: AbstractValue):
36043592
if type(d) is Var else d for d in a.shape))
36053593
if type(a) is DShapedArray else a for a in out_avals]
36063594

3607-
# jaxpr input effects are indexed to include jaxpr.constvars, but the eqn
3608-
# should have effects indexed only on its explicit arguments
3609-
effs = {e.replace(input_index=e.input_index - len(call_jaxpr.constvars))
3610-
if isinstance(e, effects.JaxprInputEffect)
3611-
else e for e in call_jaxpr.effects}
3612-
3613-
return out_type, effs
3595+
return out_type, call_jaxpr.effects
36143596

36153597
def _check_map(ctx_factory, prim, in_avals, params):
36163598
if "call_jaxpr" not in params:

jax/_src/effects.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
from __future__ import annotations
5454

5555
from collections.abc import Iterable, Set
56-
from typing import Any
5756

5857

5958
class Effect:
@@ -66,29 +65,18 @@ class JaxprInputEffect(Effect):
6665
6766
This is used as a base class for effects associated with inputs, e.g.,
6867
reading/writing from mutable inputs.
69-
70-
When used in a `JaxprEqn`, `input_index` refers to `eqn.invars`.
71-
When used in a `Jaxpr`, `input_index` refers to `jaxpr.constvars + jaxpr.invars`.
7268
"""
7369

74-
def __init__(self, input_index: Any):
75-
self.input_index = input_index
76-
77-
def replace(self, *, input_index: Any | None = None):
78-
if input_index is None:
79-
input_index = self.input_index
80-
return self.__class__(input_index)
81-
8270
def __eq__(self, other):
8371
if not isinstance(other, JaxprInputEffect):
8472
return NotImplemented
85-
return self.input_index == other.input_index
73+
return True
8674

8775
def __hash__(self):
88-
return hash((self.__class__, self.input_index))
76+
return hash(self.__class__)
8977

9078
def __repr__(self):
91-
return f"{self.__class__.__name__}({self.input_index})"
79+
return f"{self.__class__.__name__}()"
9280

9381
class EffectTypeSet:
9482

jax/_src/interpreters/ad.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,11 +1473,8 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_
14731473
new_debug_info = jaxpr.jaxpr.debug_info._replace(
14741474
arg_names=new_arg_names, result_paths=new_result_paths)
14751475
constvars = jaxpr.jaxpr.constvars
1476-
new_effects = pe._renumber_effects(
1477-
(*constvars, *new_invars), (*constvars, *jaxpr.jaxpr.invars),
1478-
jaxpr.jaxpr.effects)
14791476
new_jaxpr = core.Jaxpr(constvars, new_invars, new_outvars, jaxpr.jaxpr.eqns,
1480-
new_effects, new_debug_info)
1477+
jaxpr.jaxpr.effects, new_debug_info)
14811478
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)
14821479

14831480
def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int],

jax/_src/interpreters/partial_eval.py

Lines changed: 2 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,11 +1601,7 @@ def _move_invars_right(jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]):
16011601
invars, rest = split_list(jaxpr.jaxpr.invars, [len(to_move)])
16021602
left_invars, right_invars = partition_list(to_move, invars)
16031603
new_invars = [*left_invars, *right_invars, *rest]
1604-
new_effs = _renumber_effects(
1605-
(*jaxpr.jaxpr.constvars, *new_invars),
1606-
(*jaxpr.jaxpr.constvars, *jaxpr.jaxpr.invars),
1607-
jaxpr.jaxpr.effects)
1608-
new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars, effects=new_effs)
1604+
new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars)
16091605
return jaxpr.replace(jaxpr=new_jaxpr)
16101606

16111607
def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
@@ -1619,23 +1615,15 @@ def _move_binders_to_front(jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]
16191615
assert len(jaxpr.in_avals) == len(to_move)
16201616
constvars, invars = jaxpr.jaxpr.constvars, jaxpr.jaxpr.invars
16211617
new_invars = _move_to_front(invars, to_move)
1622-
new_effs = _renumber_effects(
1623-
(*constvars, *new_invars), (*constvars, *invars), jaxpr.jaxpr.effects)
16241618
if jaxpr.jaxpr.debug_info.arg_names is None:
16251619
new_arg_names = None
16261620
else:
16271621
new_arg_names = tuple(_move_to_front(jaxpr.jaxpr.debug_info.arg_names, to_move))
16281622
dbg = jaxpr.jaxpr.debug_info._replace(arg_names=new_arg_names)
16291623
new_jaxpr = jaxpr.jaxpr.replace(
1630-
constvars=constvars, invars=new_invars, effects=new_effs, debug_info=dbg)
1624+
constvars=constvars, invars=new_invars, debug_info=dbg)
16311625
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)
16321626

1633-
def _renumber_effects(new_vars, old_vars, effs):
1634-
newvar_idxs = {id(v): i for i, v in enumerate(new_vars)}
1635-
old_to_new = {i: newvar_idxs[id(v)] for i, v in enumerate(old_vars)}
1636-
return {e.replace(input_index=old_to_new[e.input_index])
1637-
if isinstance(e, effects.JaxprInputEffect) else e for e in effs}
1638-
16391627
def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence:
16401628
return ([elt for elt, move in zip(lst, to_move) if move] +
16411629
[elt for elt, move in zip(lst, to_move) if not move])
@@ -1763,35 +1751,6 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
17631751
all_vars[outvar] = None # type: ignore
17641752
mut_arrays.add(outvar)
17651753
for eff in eqn.effects:
1766-
if isinstance(eff, effects.JaxprInputEffect):
1767-
if eff.input_index >= len(eqn.invars):
1768-
# TODO(mattjj): ask for forgiveness
1769-
dbg = type('Fake', (), {'resolve_result_paths': lambda self_: self_,
1770-
'assert_arg_names': lambda _, __: None,
1771-
'assert_result_paths': lambda _, __: None,
1772-
})()
1773-
raise ValueError(
1774-
f"`JaxprInputEffect` {eff} is invalid."
1775-
f"\n Equation: {eqn}\n"
1776-
"\n Jaxpr: "
1777-
f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore
1778-
eqn_invar = eqn.invars[eff.input_index]
1779-
if type(eqn_invar) is core.Literal or eqn_invar in mut_arrays:
1780-
continue
1781-
if (input_index := all_vars.get(eqn_invar, sentinel)) is sentinel:
1782-
# TODO(mattjj): ask for forgiveness
1783-
dbg = type('Fake', (), {'resolve_result_paths': lambda self_: self_,
1784-
'assert_arg_names': lambda _, __: None,
1785-
'assert_result_paths': lambda _, __: None,
1786-
})()
1787-
raise ValueError(
1788-
f"`JaxprInputEffect` {eff} does not have "
1789-
f"corresponding jaxpr input: {eqn_invar=}."
1790-
f"\n Equation: {eqn}\n"
1791-
f"\n Effects: {eqn.effects}\n"
1792-
"\n Jaxpr: "
1793-
f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore
1794-
eff = eff.replace(input_index=input_index)
17951754
jaxpr_effects.add(eff)
17961755
return jaxpr_effects
17971756

jax/_src/lax/control_flow/common.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,7 @@ def _pad_constvars(jaxpr: core.Jaxpr, left: tuple[core.AvalQDD, ...],
9696
def make_var(aq):
9797
return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd)
9898
constvars = [*map(make_var, left), *jaxpr.constvars, *map(make_var, right)]
99-
effs = pe._renumber_effects([*constvars, *jaxpr.invars],
100-
[*jaxpr.constvars, *jaxpr.invars], jaxpr.effects)
101-
jaxpr = jaxpr.replace(constvars=constvars, effects=effs)
99+
jaxpr = jaxpr.replace(constvars=constvars)
102100
config.enable_checks.value and core.check_jaxpr(jaxpr)
103101
return jaxpr
104102

@@ -112,11 +110,7 @@ def _dedup_consts(jaxpr, const_ids):
112110
outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x
113111
for x in jaxpr.outvars]
114112
constvars = list(newvars.values())
115-
effs = pe._renumber_effects(
116-
[*constvars, *jaxpr.invars],
117-
[*map(canonicalize.get, jaxpr.constvars), *jaxpr.invars], jaxpr.effects)
118-
jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars,
119-
effects=effs)
113+
jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars)
120114
config.enable_checks.value and core.check_jaxpr(jaxpr)
121115
return jaxpr
122116

jax/_src/lax/control_flow/conditionals.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,6 @@ def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects:
445445
joined_effects = set()
446446
for b in branches:
447447
for eff in b.effects:
448-
if isinstance(eff, effects.JaxprInputEffect):
449-
# Offset index to handle predicate
450-
eff = eff.replace(input_index=eff.input_index + 1)
451448
joined_effects.add(eff)
452449
return joined_effects
453450

jax/_src/lax/control_flow/loops.py

Lines changed: 3 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -919,13 +919,7 @@ def _rearrange_mutable_binders(
919919
immut_names, mut_names = partition_list(is_mutable, names)
920920
new_arg_names = [*fst, *mut_names, *immut_names, *rst]
921921
dbg = jaxpr.jaxpr.debug_info._replace(arg_names=new_arg_names)
922-
923-
# TODO(mattjj): don't we need to re-number effects? test coverage?
924-
new_effs = pe._renumber_effects((*jaxpr.jaxpr.constvars, *new_invars),
925-
(*jaxpr.jaxpr.constvars, *jaxpr.jaxpr.invars),
926-
jaxpr.jaxpr.effects)
927-
new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars, effects=new_effs,
928-
debug_info=dbg)
922+
new_jaxpr = jaxpr.jaxpr.replace(invars=new_invars, debug_info=dbg)
929923
if config.enable_checks.value: core.check_jaxpr(new_jaxpr)
930924
return ClosedJaxpr(new_jaxpr, jaxpr.consts)
931925

@@ -1777,16 +1771,8 @@ def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts
17771771
) -> effects.Effects:
17781772
joined_effects = set()
17791773
for eff in cond_jaxpr.effects:
1780-
if isinstance(eff, effects.JaxprInputEffect):
1781-
index = eff.input_index
1782-
if index >= cond_nconsts:
1783-
index += body_nconsts
1784-
eff = eff.replace(input_index=index)
17851774
joined_effects.add(eff)
17861775
for eff in body_jaxpr.effects:
1787-
if isinstance(eff, effects.JaxprInputEffect):
1788-
index = eff.input_index + cond_nconsts
1789-
eff = eff.replace(input_index=index)
17901776
joined_effects.add(eff)
17911777
return joined_effects
17921778

@@ -2285,17 +2271,6 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args,
22852271
[cond_nconsts,
22862272
body_nconsts])
22872273

2288-
# Check if the same Ref is written to in both cond and body.
2289-
cond_write_ids = {id(cond_consts_avals[effect.input_index])
2290-
for effect in cond_jaxpr.effects if isinstance(effect, state.WriteEffect)}
2291-
cond_has_writes = len(cond_write_ids) > 0
2292-
body_write_ids = {id(body_consts_avals[effect.input_index])
2293-
for effect in body_jaxpr.effects if isinstance(effect, state.WriteEffect)}
2294-
write_to_both_ids = cond_write_ids & body_write_ids
2295-
if write_to_both_ids:
2296-
raise NotImplementedError(
2297-
"Cannot write to the same ref in both cond and body of while loop.")
2298-
22992274
cond_is_ref = [
23002275
isinstance(aval, state.AbstractRef) and should
23012276
for aval, should in zip(cond_consts_avals, cond_consts_discharge)
@@ -2315,13 +2290,6 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args,
23152290
num_body_refs = sum(body_is_ref)
23162291
num_remaining_body_consts = body_nconsts - num_body_refs
23172292
num_out_body_consts = num_remaining_body_consts
2318-
if cond_has_writes:
2319-
# If the cond has writes, we need to add the cond consts into the body
2320-
# consts since we need to evaluate the cond condition in the body.
2321-
remaining_body_consts = [*remaining_cond_consts, *remaining_body_consts]
2322-
remaining_body_const_avals = [*remaining_cond_const_avals,
2323-
*remaining_body_const_avals]
2324-
num_remaining_body_consts += num_remaining_cond_consts
23252293

23262294
num_carry = len(in_avals) - body_nconsts - cond_nconsts
23272295
body_jaxpr, body_jaxpr_consts = body_jaxpr.jaxpr, body_jaxpr.consts
@@ -2353,23 +2321,8 @@ def new_body(*consts_refs_carry):
23532321
consts, body_refs, cond_refs, carry = split_list(
23542322
consts_refs_carry,
23552323
[num_remaining_body_consts, num_body_refs, num_cond_refs])
2356-
if cond_has_writes:
2357-
# We run the cond jaxpr in the body so that Refs that are updated
2358-
# in the cond jaxpr are persisted via the carry.
2359-
cond_consts, body_consts = split_list(consts, [num_remaining_cond_consts])
2360-
cond_consts_and_refs = merge_lists(cond_is_ref, cond_consts, cond_refs)
2361-
cond_carry_refs = core.eval_jaxpr(discharged_cond_jaxpr, (),
2362-
*cond_consts_and_refs,
2363-
*carry)
2364-
# Note: in order to handle the same Ref being updated in both the cond
2365-
# and body, we would need to interleave the updated cond_carry_refs into
2366-
# body_refs here.
2367-
# Currently we disallow this so we don't need to handle it.
2368-
_, cond_refs_out = split_list(cond_carry_refs, [1])
2369-
assert len(cond_refs_out) == len(cond_refs)
2370-
else:
2371-
body_consts = consts
2372-
cond_refs_out = cond_refs
2324+
body_consts = consts
2325+
cond_refs_out = cond_refs
23732326

23742327
body_consts_and_refs = merge_lists(body_is_ref, body_consts, body_refs)
23752328
body_carry_refs = core.eval_jaxpr(discharged_body_jaxpr, (),

jax/_src/pjit.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1833,18 +1833,9 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):
18331833

18341834

18351835
def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_):
1836-
effs = _pjit_eqn_effects(jaxpr) if jaxpr.constvars else jaxpr.effects
1837-
return jaxpr.out_avals, effs
1836+
return jaxpr.out_avals, jaxpr.effects
18381837
jit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
18391838

1840-
def _pjit_eqn_effects(jaxpr):
1841-
# jaxpr input effects are indexed to include jaxpr.constvars, but the pjit eqn
1842-
# should have effects indexed only on its explicit arguments
1843-
effs = jaxpr.effects
1844-
return {e.replace(input_index=e.input_index - len(jaxpr.constvars))
1845-
if isinstance(e, effects.JaxprInputEffect) else e for e in effs}
1846-
1847-
18481839
def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
18491840
name: str, jaxpr: core.ClosedJaxpr,
18501841
num_const_args: int, in_avals,
@@ -2512,7 +2503,7 @@ def keep_where(xs, keeps):
25122503
if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects:
25132504
return used_inputs, None
25142505
else:
2515-
new_effs = _pjit_eqn_effects(dced_jaxpr)
2506+
new_effs = dced_jaxpr.effects
25162507
new_eqn = core.new_jaxpr_eqn(
25172508
[v for v, used in zip(eqn.invars, used_inputs) if used],
25182509
[v for v, used in zip(eqn.outvars, used_outputs) if used],

jax/_src/state/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
Transform as Transform,
2222
TransformedRef as TransformedRef,
2323
WriteEffect as WriteEffect,
24-
get_ref_state_effects as get_ref_state_effects,
2524
get_transforms_shape as get_transforms_shape,
2625
shaped_array_ref as shaped_array_ref,
2726
)

jax/_src/state/discharge.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -685,39 +685,10 @@ def _run_state_abstract_eval(*avals: core.AbstractValue, jaxpr: core.Jaxpr,
685685
is_initialized: tuple[bool, ...]):
686686
del which_linear
687687
assert sum(is_initialized) == len(avals)
688-
# When we abstractly evaluate `run_state`, we want to keep track of which
689-
# input avals are `Ref`s and which are not. If an aval is a `Ref`, we want to
690-
# "propagate" out its inner effects. Otherwise, the effects are local to this
691-
# `run_state`.
692-
inner_to_outer_aval_mapping = {}
693-
outer_ref_index = 0
694-
for i, is_init in enumerate(is_initialized):
695-
if not is_init:
696-
pass
697-
inner_to_outer_aval_mapping[i] = outer_ref_index
698-
outer_ref_index += 1
699-
nonlocal_effects = set()
700-
is_ref = {i for i, aval in enumerate(avals) if isinstance(aval, AbstractRef)}
701-
for eff in jaxpr.effects:
702-
if not isinstance(eff, RefEffect):
703-
nonlocal_effects.add(eff)
704-
continue
705-
if eff.input_index not in inner_to_outer_aval_mapping:
706-
# This means that this effect corresponds to an uninitialized Ref and
707-
# should not propagate out of the primitive.
708-
continue
709-
# If we do propagate the effect, we need to update the input index to
710-
# correspond to the outer index.
711-
outer_index = inner_to_outer_aval_mapping[eff.input_index]
712-
if outer_index in is_ref:
713-
# This means that the effect corresponds to a Ref from an outside scope.
714-
nonlocal_effects.add(
715-
eff.replace(input_index=inner_to_outer_aval_mapping[eff.input_index])
716-
)
717688
assert len(jaxpr.invars) == len(is_initialized)
718689
if not all(is_initialized):
719690
raise NotImplementedError # Uninitialized refs are not in avals.
720-
return avals, nonlocal_effects
691+
return avals, jaxpr.effects
721692
run_state_p.def_effectful_abstract_eval(_run_state_abstract_eval)
722693

723694
def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *,

0 commit comments

Comments
 (0)