Skip to content

Commit aaabb97

Browse files
cperivolGoogle-ML-Automation
authored andcommitted
Partial discharge for scan_p ops.
PiperOrigin-RevId: 707558502
1 parent 2259a13 commit aaabb97

File tree

2 files changed

+57
-14
lines changed

2 files changed

+57
-14
lines changed

jax/_src/lax/control_flow/loops.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,7 +1161,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
11611161
f'called with sequence whose items have type\n{_avals_short(x_avals_mapped)}')
11621162
return [*init_avals, *y_avals], jaxpr.effects
11631163

1164-
def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
1164+
def _scan_state_partial_discharge_rule(should_discharge, in_avals, out_avals, *args, jaxpr, num_consts,
11651165
num_carry, linear, unroll, reverse, length,
11661166
_split_transpose):
11671167
# We're shuffling parameters between three signatures for the scan body:
@@ -1182,39 +1182,59 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
11821182
n_ys = len(out_avals) - n_carry
11831183
consts_avals, carry_avals, xs_avals = split_list_checked(in_avals,
11841184
[n_consts, n_carry, n_xs])
1185-
is_ref_const = [isinstance(a, state.AbstractRef) for a in consts_avals]
1186-
assert not any(isinstance(a, state.AbstractRef) for a in carry_avals)
1187-
is_ref_xs = [isinstance(a, state.AbstractRef) for a in xs_avals]
1185+
consts_discharge, carry_discharge, xs_discharge = split_list_checked(should_discharge,
1186+
[n_consts, n_carry, n_xs])
1187+
1188+
is_ref_const = [s and isinstance(a, state.AbstractRef) for s, a in zip(consts_discharge, consts_avals)]
1189+
assert not any(isinstance(a, state.AbstractRef) for a in carry_avals)
1190+
assert not any(carry_discharge)
1191+
is_ref_xs = [s and isinstance(a, state.AbstractRef) for s, a in zip(xs_discharge, xs_avals)]
11881192
n_ref_consts = sum(is_ref_const)
11891193
n_val_consts = n_consts - n_ref_consts
11901194
n_ref_xs = sum(is_ref_xs)
11911195
n_val_xs = n_xs - n_ref_xs
1192-
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
1196+
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, (), should_discharge=should_discharge)
11931197
if discharged_consts:
11941198
raise NotImplementedError("Discharged jaxpr has consts. If you see this, "
11951199
"please open an issue at "
11961200
"https://github.com/jax-ml/jax/issues")
11971201
def wrapped(*wrapped_args):
1198-
val_consts, ref_consts_in, carry_in, val_xs, ref_xs_in = split_list_checked(wrapped_args,
1199-
[n_val_consts, n_ref_consts, n_carry, n_val_xs, n_ref_xs])
1202+
val_consts, carry_in, ref_consts_in, val_xs, ref_xs_in = split_list_checked(wrapped_args,
1203+
[n_val_consts, n_carry, n_ref_consts, n_val_xs, n_ref_xs])
12001204
consts = merge_lists(is_ref_const, val_consts, ref_consts_in)
12011205
xs = merge_lists(is_ref_xs, val_xs, ref_xs_in)
12021206
outs = core.eval_jaxpr(discharged_jaxpr, (), *consts, *carry_in, *xs)
12031207
carry_out, ys, ref_consts_out, ref_xs_out = split_list_checked(outs,
12041208
[n_carry, n_ys, n_ref_consts, n_ref_xs])
1205-
return [*ref_consts_out, *carry_out, *ys, *ref_xs_out]
1209+
return [*carry_out, *ref_consts_out, *ys, *ref_xs_out]
12061210

12071211
def arrange_jaxpr_args_for_wrapped(args):
12081212
consts, carry_in, xs = split_list_checked(args, [n_consts, n_carry, n_xs])
12091213
val_consts, ref_consts_in = partition_list(is_ref_const, consts)
12101214
val_xs, ref_xs_in = partition_list(is_ref_xs, xs)
1211-
return *val_consts, *ref_consts_in, *carry_in, *val_xs, *ref_xs_in
1215+
return *val_consts, *carry_in, *ref_consts_in, *val_xs, *ref_xs_in
12121216

1217+
# Rearrange the arguments such that they are:
1218+
# val_consts, carry, ref_consts, val_xs, ref_xs
1219+
#
1220+
# It is important that carry is immediately after the val_consts
1221+
# because pallas pattern matches the leading argument type to figure
1222+
# out if a scan_p eqn is equivalent to a fori loop (see
1223+
# `pallas.utils.pattern_match_scan_to_fori_loop()`).
12131224
args_for_wrapped = arrange_jaxpr_args_for_wrapped(args)
12141225
linear_for_wrapped = arrange_jaxpr_args_for_wrapped(linear)
12151226
avals_for_wrapped = arrange_jaxpr_args_for_wrapped(in_avals)
1216-
avals_for_wrapped_no_refs = [aval.inner_aval if isinstance(aval, state.AbstractRef) else aval
1217-
for aval in avals_for_wrapped]
1227+
# Get the const avals that we need to discharge and leave the rest as-is.
1228+
deref_const_avals = tuple(c.inner_aval for c in avals_for_wrapped[n_val_consts + n_carry:n_consts + n_carry])
1229+
deref_xs_avals = tuple(x.inner_aval for x in avals_for_wrapped[n_consts + n_carry + n_val_xs:])
1230+
avals_for_wrapped_no_refs = (
1231+
avals_for_wrapped[: n_val_consts + n_carry]
1232+
+ deref_const_avals
1233+
+ avals_for_wrapped[n_consts + n_carry :n_consts + n_carry + n_val_xs]
1234+
+ deref_xs_avals
1235+
)
1236+
# TODO(cperivol): avoid tracing the jaxpr twice. When doing so don't
1237+
# forget to manage the effects.
12181238
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), avals_for_wrapped_no_refs)
12191239
all_out = scan_p.bind(*args_for_wrapped,
12201240
jaxpr=core.ClosedJaxpr(new_jaxpr, ()),
@@ -1224,8 +1244,8 @@ def arrange_jaxpr_args_for_wrapped(args):
12241244
unroll=unroll,
12251245
reverse=reverse,
12261246
linear=linear_for_wrapped, _split_transpose=_split_transpose)
1227-
ref_consts_out, carry_out, ys, ref_xs_out = split_list_checked(all_out,
1228-
[n_ref_consts, n_carry, n_ys, n_ref_xs])
1247+
carry_out, ref_consts_out, ys, ref_xs_out = split_list_checked(all_out,
1248+
[n_carry, n_ref_consts, n_ys, n_ref_xs])
12291249
refs_out_matching_in_avals = [
12301250
*merge_lists(is_ref_const, [None] * n_val_consts, ref_consts_out),
12311251
*[None] * n_carry,
@@ -1248,7 +1268,7 @@ def arrange_jaxpr_args_for_wrapped(args):
12481268
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
12491269
pe.padding_rules[scan_p] = _scan_padding_rule
12501270
pe.dce_rules[scan_p] = _scan_dce_rule
1251-
state_discharge.register_discharge_rule(scan_p)(_scan_state_discharge_rule)
1271+
state_discharge.register_partial_discharge_rule(scan_p)(_scan_state_partial_discharge_rule)
12521272

12531273
def _propagate_mem_kind_scan(*xm, reverse, length, num_consts, num_carry, jaxpr,
12541274
linear, unroll, _split_transpose):

tests/state_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,29 @@ def f(a_ref, b_ref):
776776
self.assertEqual(prim_count(swap_p, jaxpr) // 2, prim_count(swap_p, discharged_jaxpr))
777777
self.assertEqual(prim_count(get_p, jaxpr) // 2, prim_count(get_p, discharged_jaxpr))
778778

779+
def test_partial_fori_discharge(self):
780+
def f(a_ref, b_ref):
781+
def body(i, st):
782+
a_ref[...] += 2 * i
783+
b_ref[...] += i
784+
return ()
785+
lax.fori_loop(0, 5, body, init_val=())
786+
return a_ref[...], b_ref[...]
787+
788+
ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x)))
789+
f_jaxpr = jax.make_jaxpr(f)(ref(1.), ref(2.))
790+
jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, True])
791+
# Effects on y_ref were discharged away but not the effects on x_ref
792+
self.assertEqual(f_jaxpr.effects, {ReadEffect(0), WriteEffect(0), ReadEffect(1), WriteEffect(1)})
793+
self.assertEqual(jaxpr.effects, {ReadEffect(0), WriteEffect(0)})
794+
# x_ref arg is still a reference but y_ref is discharged
795+
self.assertNotIsInstance(jaxpr.invars[1].aval, AbstractRef)
796+
self.assertIsInstance(jaxpr.invars[0].aval, AbstractRef)
797+
# x_ref value is returned as part of the discharged refs set.
798+
self.assertLen(f_jaxpr.out_avals, 2)
799+
self.assertLen(jaxpr.outvars, 3)
800+
801+
779802
if CAN_USE_HYPOTHESIS:
780803

781804
def index_arrays(size, idx_shape):

0 commit comments

Comments
 (0)