Skip to content

Commit 263d4d1

Browse files
Merge pull request jax-ml#25369 from jax-ml:mutable-arrays-ad
PiperOrigin-RevId: 704685653
2 parents 8e7aaa7 + fc2edbf commit 263d4d1

File tree

8 files changed

+76
-7
lines changed

8 files changed

+76
-7
lines changed

jax/_src/core.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,8 @@ class Primitive:
431431
call_primitive: bool = False
432432
# set for map primitives processed in final style.
433433
map_primitive: bool = False
434+
# set for ref primitives
435+
ref_primitive: bool = False
434436

435437
def __init__(self, name: str):
436438
self.name = name
@@ -1882,6 +1884,7 @@ def __repr__(self) -> str: return 'Mutable' + repr(self[...])
18821884
def mutable_array(init_val):
18831885
return mutable_array_p.bind(init_val)
18841886
mutable_array_p = Primitive('mutable_array')
1887+
mutable_array_p.ref_primitive = True
18851888

18861889
class InternalMutableArrayEffect(effects.Effect):
18871890
pass
@@ -1899,6 +1902,18 @@ def _mutable_array_impl(init_val):
18991902
aval = get_aval(init_val)
19001903
return MutableArray(AbstractRef(aval), init_val)
19011904

1905+
def freeze(ref):
1906+
return freeze_p.bind(ref)
1907+
freeze_p = Primitive('freeze')
1908+
freeze_p.ref_primitive = True
1909+
1910+
@freeze_p.def_effectful_abstract_eval
1911+
def freeze_abstract_eval(ref_aval):
1912+
return ref_aval.inner_aval, {internal_mutable_array_effect}
1913+
1914+
@freeze_p.def_impl
1915+
def _freeze_impl(ref):
1916+
return ref[()]
19021917

19031918
class AbstractToken(AbstractValue):
19041919
def str_short(self, short_dtypes=False): return 'Tok'
@@ -2516,10 +2531,11 @@ def write(v: Var, a: AbstractValue) -> None:
25162531

25172532
# Check the computed effect type matches the eqn's annotation, and is
25182533
# included in the jaxpr's annotation.
2519-
if prim is mutable_array_p:
2520-
outvar, = eqn.outvars
2521-
in_idx[outvar] = None # type: ignore
2522-
mut_arrays.add(outvar)
2534+
if prim.ref_primitive:
2535+
if prim is mutable_array_p:
2536+
outvar, = eqn.outvars
2537+
in_idx[outvar] = None # type: ignore
2538+
mut_arrays.add(outvar)
25232539
if eqn.effects != eqn_effects:
25242540
raise JaxprTypeError("Inferred effects do not match equation effects. "
25252541
f"Equation effects: {eqn.effects}. "

jax/_src/interpreters/ad.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,20 @@ def write_primal(v, val):
267267
with ctx:
268268
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
269269
for eqn in jaxpr.eqns[::-1]:
270+
if eqn.primitive.ref_primitive:
271+
if eqn.primitive is core.mutable_array_p:
272+
val_var, = eqn.invars
273+
ref_var, = eqn.outvars
274+
ref = read_primal(ref_var)
275+
ct_out = core.freeze(ref)
276+
write_cotangent(eqn.primitive, val_var, ct_out)
277+
elif eqn.primitive is core.freeze_p:
278+
val_var, = eqn.outvars
279+
ref_var, = eqn.invars
280+
ct_in = instantiate_zeros(read_cotangent(val_var))
281+
write_primal(ref_var, core.mutable_array(ct_in))
282+
continue
283+
270284
invals = map(read_primal, eqn.invars)
271285
if eqn.primitive.multiple_results:
272286
cts_in = map(read_cotangent, eqn.outvars)

jax/_src/interpreters/partial_eval.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1009,21 +1009,25 @@ def partial_eval_jaxpr_stateful(
10091009
in_inst: bool | Sequence[bool],
10101010
ensure_out_unknowns: bool | Sequence[bool],
10111011
ensure_out_inst: bool | Sequence[bool],
1012-
saveable: Callable[..., RematCases_],
1012+
saveable: Callable[..., RematCases_] | None,
10131013
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]:
10141014
if type(in_inst) is bool:
10151015
in_inst = (in_inst,) * len(jaxpr.invars)
10161016
if type(ensure_out_unknowns) is bool:
10171017
ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
10181018
if type(ensure_out_inst) is bool:
10191019
ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
1020+
if saveable is None:
1021+
saveable = everything_saveable
10201022
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
10211023
_partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
10221024
tuple(in_inst),
10231025
tuple(ensure_out_unknowns),
10241026
tuple(ensure_out_inst), saveable)
10251027
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref
10261028

1029+
everything_saveable = lambda *_, **__: True
1030+
10271031
@weakref_lru_cache
10281032
def _partial_eval_jaxpr_custom_cached(
10291033
jaxpr: Jaxpr,

jax/_src/lax/lax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,6 +1730,9 @@ def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray:
17301730
val = ad_util.zeros_like_aval(aval.inner_aval)
17311731
return core.mutable_array(val)
17321732

1733+
# TODO(dougalm): this is nonsense but it's here because in places like
1734+
# custom_vjp we assume that all arguments have tangent spaces. We could have
1735+
# a distinct NotATangentType value instead.
17331736
ad_util.aval_zeros_likers[state.AbstractRef] = zeros_like_abstract_ref # type: ignore
17341737

17351738
def iota(dtype: DTypeLike, size: int) -> Array:

jax/_src/pjit.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2156,8 +2156,18 @@ def _pjit_partial_eval(trace, *in_tracers,
21562156

21572157
known_ins = tuple(pv.is_known() for pv in in_pvals)
21582158
unknown_ins = tuple(not k for k in known_ins)
2159-
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
2160-
pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
2159+
if any(isinstance(e, (RefEffect, core.InternalMutableArrayEffect))
2160+
for e in jaxpr.effects):
2161+
known_jaxpr_, unknown_jaxpr_, unknown_outs, _, num_res_val, num_res_ref = \
2162+
pe.partial_eval_jaxpr_stateful(jaxpr.jaxpr, unknown_ins, unknown_ins,
2163+
False, False, None)
2164+
if num_res_ref: raise NotImplementedError
2165+
known_jaxpr = pe.ClosedJaxpr(known_jaxpr_, jaxpr.consts)
2166+
unknown_jaxpr = pe.ClosedJaxpr(unknown_jaxpr_, jaxpr.consts)
2167+
res_avals = unknown_jaxpr.in_avals[:num_res_val]
2168+
else:
2169+
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
2170+
pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
21612171
unknown_outs = tuple(unknown_outs)
21622172
known_outs = tuple(not uk for uk in unknown_outs)
21632173
num_residuals = len(res_avals)

jax/_src/state/discharge.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ def _eval_jaxpr_discharge_state(
153153
[invar], [outvar] = eqn.invars, eqn.outvars
154154
ans = env.read(invar)
155155
refs_to_discharge.add(id(outvar.aval))
156+
elif eqn.primitive is core.freeze_p:
157+
[invar], [outvar] = eqn.invars, eqn.outvars
158+
ans = env.read(invar)
159+
refs_to_discharge.remove(id(invar.aval))
156160
elif (any(should_discharge)
157161
or core.internal_mutable_array_effect in eqn.effects
158162
):

jax/_src/state/primitives.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,3 +654,8 @@ def _broadcast_to_abstract_eval(aval, *, shape):
654654
mlir.register_lowering(
655655
broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
656656
)
657+
658+
# === AD rules for mutable arrays ===
659+
660+
ad.defjvp(core.mutable_array_p, lambda g, _: core.mutable_array(g))
661+
ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g))

tests/mutable_array_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,5 +232,18 @@ def f():
232232
x = f()
233233
self.assertArraysEqual(x, jnp.zeros(8))
234234

235+
def test_grad_mutable_array(self):
236+
@jax.jit
237+
def f(x):
238+
x_ = core.mutable_array(x)
239+
x_[()] = x_[()] + x_[()]
240+
y = core.freeze(x_)
241+
return y
242+
243+
ans = jax.grad(f)(1.)
244+
expected = 2.0
245+
self.assertAllClose(ans, expected, check_dtypes=False)
246+
247+
235248
if __name__ == '__main__':
236249
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)