Skip to content

Commit 6fba4ec

Browse files
mattjjGoogle-ML-Automation
authored andcommitted
PR jax-ml#27576: [attrs] experimental appendattr
Imported from GitHub PR jax-ml#27576 This is an experimental extension to attrs. Attrs should be considered both experimental and deprecated. This PR also includes some fixes for getattr/setattr. Copybara import of the project: -- 3b1ea1a by Matthew Johnson <[email protected]>: [attrs] experimental appendattr Merging this change closes jax-ml#27576 COPYBARA_INTEGRATE_REVIEW=jax-ml#27576 from mattjj:appendattr b937952 PiperOrigin-RevId: 741662724
1 parent 1771936 commit 6fba4ec

File tree

5 files changed

+430
-90
lines changed

5 files changed

+430
-90
lines changed

jax/_src/interpreters/partial_eval.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ def identity(x): return x
5858
AvalId = int
5959
ConstId = int
6060

61+
AttrKind = Any
62+
PyTree = Any
63+
64+
# Attrs flavors, see jax/experimental/attrs.py
65+
ReadWrite = type('ReadWrite', (), {})()
66+
Append = type('Append', (), {})()
67+
6168
def _update_annotation_known(
6269
f: lu.WrappedFun,
6370
orig_type: InputType | None,
@@ -1553,6 +1560,17 @@ def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
15531560
"""Reorder `invars` by moving those indicated in `to_move` to the back."""
15541561
return move_binders_to_front(closed_jaxpr, map(op.not_, to_move))
15551562

1563+
def move_outvars_to_back(jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr:
1564+
return _move_outvars_to_back(jaxpr, tuple(to_move))
1565+
1566+
@weakref_lru_cache
1567+
def _move_outvars_to_back(jaxpr, to_move):
1568+
new_outvars = ([e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if not m] +
1569+
[e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if m])
1570+
return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(outvars=new_outvars))
1571+
1572+
1573+
15561574
class DynamicJaxprTracer(core.Tracer):
15571575
__slots__ = ['aval', '_debug_info']
15581576

@@ -1657,7 +1675,7 @@ class JaxprStackFrame:
16571675
eqns: list[JaxprEqn]
16581676
invars: list[Var]
16591677
effects: core.Effects
1660-
attrs_tracked: list[tuple[Any, str]]
1678+
attrs_tracked: list[tuple[Any, str, AttrKind]]
16611679
attrs_inits: list
16621680
attrs_vars: list[Var]
16631681
debug_info: core.DebugInfo
@@ -1679,10 +1697,14 @@ def __init__(self, debug_info: core.DebugInfo):
16791697
def add_eqn(self, eqn: core.JaxprEqn):
16801698
self.eqns.append(eqn)
16811699

1682-
def to_jaxpr(self, trace: DynamicJaxprTrace,
1683-
out_tracers: Sequence[Tracer],
1684-
debug_info: core.DebugInfo,
1685-
) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
1700+
def reset_states(self):
1701+
reset_states(self.attrs_tracked, self.attrs_inits)
1702+
1703+
def to_jaxpr(
1704+
self, trace: DynamicJaxprTrace,
1705+
out_tracers: Sequence[Tracer],
1706+
debug_info: core.DebugInfo,
1707+
) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]:
16861708
# It's not necessary, but we keep the tracer-to-var mapping injective:
16871709
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
16881710
invars = self.attrs_vars + self.invars
@@ -1699,7 +1721,6 @@ def to_jaxpr(self, trace: DynamicJaxprTrace,
16991721
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
17001722
jaxpr, constvals = _inline_literals(jaxpr, constvals)
17011723
init_trees = [tree_structure(init_val) for init_val in self.attrs_inits]
1702-
set_states(self.attrs_tracked, self.attrs_inits) # reset to initial values
17031724
return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked)
17041725

17051726
def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
@@ -1840,10 +1861,9 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
18401861
outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars]
18411862
new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
18421863
new_outvars = [lit_or_var(v) for v in jaxpr.outvars]
1843-
jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars,
1844-
new_eqns)
1845-
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
1846-
jaxpr_effects, jaxpr.debug_info)
1864+
effs = make_jaxpr_effects(new_constvars, new_invars, new_outvars, new_eqns)
1865+
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, effs,
1866+
jaxpr.debug_info)
18471867
return new_jaxpr, new_constvals
18481868

18491869

@@ -2172,19 +2192,23 @@ def trace_to_jaxpr_dynamic(
21722192
*,
21732193
keep_inputs: list[bool] | None = None,
21742194
) -> tuple[Jaxpr, list[AbstractValue], list[Any],
2175-
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
2195+
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]:
21762196
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
21772197
trace = DynamicJaxprTrace(fun.debug_info)
21782198
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
21792199
in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
21802200
in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
2181-
with core.set_current_trace(trace):
2182-
ans = fun.call_wrapped(*in_tracers)
2201+
try:
2202+
with core.set_current_trace(trace):
2203+
ans = fun.call_wrapped(*in_tracers)
21832204

2184-
out_tracers = map(trace.to_jaxpr_tracer, ans)
2185-
_check_no_returned_refs(fun.debug_info, out_tracers)
2186-
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info)
2187-
del trace, fun, in_tracers, out_tracers, ans
2205+
out_tracers = map(trace.to_jaxpr_tracer, ans)
2206+
_check_no_returned_refs(fun.debug_info, out_tracers)
2207+
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info)
2208+
del fun, in_tracers, out_tracers, ans
2209+
finally:
2210+
trace.frame.reset_states()
2211+
del trace
21882212

21892213
config.enable_checks.value and core.check_jaxpr(jaxpr)
21902214
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
@@ -2242,14 +2266,14 @@ def trace_to_jaxpr_dynamic2(
22422266
tuple[AbstractedAxisName, ...],
22432267
]
22442268

2245-
AttrsTracked = list[tuple[Any, str]]
2269+
AttrsTracked = list[tuple[Any, str, AttrKind]]
22462270
AttrStates = list
2247-
def set_states(attrs_tracked: AttrsTracked, vals: AttrStates):
2248-
for ((obj, attr), val) in zip(attrs_tracked, vals):
2271+
def reset_states(attrs_tracked: AttrsTracked, init_vals: AttrStates) -> None:
2272+
for ((obj, attr, _), val) in zip(attrs_tracked, init_vals):
22492273
setattr(obj, attr, val) if val is not dne_sentinel else delattr(obj, attr)
22502274

2251-
def get_states(attrs_tracked: AttrsTracked):
2252-
return [getattr(obj, attr) for (obj, attr) in attrs_tracked]
2275+
def get_states(attrs_tracked: AttrsTracked) -> list[PyTree]:
2276+
return [getattr(obj, attr) for (obj, attr, kind) in attrs_tracked]
22532277

22542278
@register_static
22552279
class DoesNotExist: ...

jax/_src/lax/control_flow/loops.py

Lines changed: 64 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,17 @@ def _create_jaxpr(init):
298298
if len(out_tree_children) != 2:
299299
msg = "scan body output must be a pair, got {}."
300300
raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
301-
_, carry_avals_out, _ = split_list(
302-
jaxpr.out_avals, [len(attrs_tracked), out_tree_children[0].num_leaves])
301+
302+
if attrs_tracked:
303+
appends_out = [kind is pe.Append for *_, (_, _, kind) in attrs_tracked]
304+
jaxpr = pe.move_outvars_to_back(
305+
jaxpr, appends_out + [False] * (len(jaxpr.out_avals) - len(appends_out)))
306+
num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind)
307+
in attrs_tracked if kind is pe.ReadWrite)
308+
_, carry_avals_out, _ = split_list(
309+
jaxpr.out_avals, [num_attr_carry, out_tree_children[0].num_leaves])
310+
else:
311+
carry_avals_out, _ = split_list(jaxpr.out_avals, [out_tree_children[0].num_leaves])
303312
return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr,
304313
consts, out_tree, out_tree_children, attrs_tracked)
305314

@@ -332,37 +341,59 @@ def _create_jaxpr(init):
332341
raise ValueError("`unroll` must be a `bool` or a positive `int`.")
333342
if attrs_tracked:
334343
in_state = _get_states(attrs_tracked)
335-
in_carry, in_ext = split_list(in_flat, [num_carry])
336-
in_flat = [*in_state, *in_carry, *in_ext]
337-
num_carry += len(attrs_tracked)
344+
in_flat = [*in_state, *in_flat]
345+
num_carry += len(in_state)
338346
out = scan_p.bind(*consts, *in_flat,
339347
reverse=reverse, length=length, jaxpr=jaxpr,
340348
num_consts=len(consts), num_carry=num_carry,
341349
linear=(False,) * (len(consts) + len(in_flat)),
342350
unroll=unroll,
343351
_split_transpose=_split_transpose)
344352
if attrs_tracked:
345-
out_state, out = split_list(out, [len(attrs_tracked)])
346-
_set_states(attrs_tracked, out_state)
353+
num_ext = (len(out) - len(in_state)
354+
- sum(k is pe.Append for *_, (_, _, k) in attrs_tracked))
355+
out_state, out, out_append = split_list(out, [len(in_state), num_ext])
356+
out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append)
357+
_set_states(attrs_tracked, out_attrs)
347358
return tree_unflatten(out_tree, out)
348359

349360
def _set_states(attrs_tracked, vals):
350-
from jax.experimental.attrs import jax_setattr
361+
from jax.experimental.attrs import jax_setattr, jax_extendattr
351362
valss = split_list_checked(vals, [td.num_leaves for _, td, _ in attrs_tracked])
352-
for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss):
353-
val = tree_unflatten(treedef, leaves)
354-
jax_setattr(obj, attr, val)
363+
for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss):
364+
if kind is pe.ReadWrite:
365+
val = tree_unflatten(treedef, leaves)
366+
jax_setattr(obj, attr, val)
367+
elif kind is pe.Append:
368+
val, = leaves
369+
jax_extendattr(obj, attr, val.reshape(-1, *val.shape[2:]))
370+
else:
371+
assert False
355372

356373
def _get_states(attrs_tracked):
357374
from jax.experimental.attrs import jax_getattr
358375
vals = []
359-
for treedef, _, (obj, attr) in attrs_tracked:
360-
tree = jax_getattr(obj, attr)
361-
leaves, treedef_ = tree_flatten(tree)
362-
assert treedef == treedef_
363-
vals.extend(leaves)
376+
for treedef, _, (obj, attr, kind) in attrs_tracked:
377+
if kind is pe.ReadWrite:
378+
tree = jax_getattr(obj, attr)
379+
leaves, treedef_ = tree_flatten(tree)
380+
assert treedef == treedef_
381+
vals.extend(leaves)
382+
elif kind is pe.Append:
383+
pass
384+
else:
385+
assert False
364386
return vals
365387

388+
def _merge_attrs_out(attrs_tracked, out_state, out_append):
389+
out_state_, out_append_ = iter(out_state), iter(out_append)
390+
out_attrs = [item for _, out_tree, (_, _, k) in attrs_tracked for item in
391+
(itertools.islice(out_state_, out_tree.num_leaves)
392+
if k is pe.ReadWrite else [next(out_append_)])]
393+
assert next(out_state_, None) is next(out_append_, None) is None
394+
return out_attrs
395+
396+
366397
def _capitalize(s):
367398
# s.capitalize() converts s[1:] to lowercase which we don't want.
368399
return s[0].capitalize() + s[1:]
@@ -662,7 +693,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
662693
# The above trace_to_jaxpr_nounits call computed loop-invariant residuals
663694
# (known values in invar_pvals_out) and also computed loop-invariant values
664695
# needed by the new jaxpr_known (in jaxpr_known_consts, which replace the
665-
# previous consts). We need to collect the computed inteisive residuals, and
696+
# previous consts). We need to collect the computed intensive residuals, and
666697
# move corresponding intensive residual binders in jaxpr_unknown to the front.
667698
res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:]
668699
intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()]
@@ -785,16 +816,21 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
785816
ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts])
786817

787818
# jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
788-
# jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
819+
# jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a, e])
789820
jaxpr_trans, attrs_tracked = _transpose_scan_jaxpr(
790821
jaxpr, num_ires, num_consts - num_ires, num_eres, ct_ys_is_zeros)
791-
linear_trans = ([False] * num_ires + [False] * len(attrs_tracked) +
822+
appends_out = [kind is pe.Append for *_, (_, _, kind) in attrs_tracked]
823+
jaxpr_trans = pe.move_outvars_to_back(
824+
jaxpr_trans, appends_out + [False] * (len(jaxpr_trans.out_avals) - len(appends_out)))
825+
num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind)
826+
in attrs_tracked if kind is pe.ReadWrite)
827+
linear_trans = ([False] * num_ires + [False] * num_attr_carry +
792828
[True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
793829
[False] * num_eres)
794830
in_state = _get_states(attrs_tracked)
795831

796832
transpose_inputs = *ires, *in_state, *ct_consts, *ct_carry, *ct_ys, *eres
797-
transpose_num_out_carry = num_consts-num_ires+num_carry+len(attrs_tracked)
833+
transpose_num_out_carry = num_consts-num_ires+num_carry+num_attr_carry
798834

799835
if not _split_transpose:
800836
outs = scan_p.bind(
@@ -889,8 +925,10 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
889925
for mask in outs_mask
890926
]
891927

892-
out_state, outs = split_list(outs, [len(attrs_tracked)])
893-
_set_states(attrs_tracked, out_state)
928+
num_outs = len(outs) - num_attr_carry - sum(appends_out)
929+
out_state, outs, out_append = split_list(outs, [num_attr_carry, num_outs])
930+
out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append)
931+
_set_states(attrs_tracked, out_attrs)
894932
ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
895933
return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres
896934

@@ -935,12 +973,10 @@ def transposed(*res1_cbar_bbar_res2):
935973
return c_bar + a_bar
936974

937975
# TODO(necula): fix arg names and results for transposed
938-
transposed_wrapped = lu.wrap_init(transposed,
939-
debug_info=jaxpr.jaxpr.debug_info)
940-
return _make_closed_jaxpr_attrs(
941-
transposed_wrapped,
942-
tuple(res1_avals + c_avals + b_carry_avals +
943-
b_ys_avals_stripped + res2_avals))
976+
transposed_wrapped = lu.wrap_init(transposed, debug_info=jaxpr.jaxpr.debug_info)
977+
trans_avals = (*res1_avals, *c_avals, *b_carry_avals, *b_ys_avals_stripped, *res2_avals)
978+
trans_jaxpr, attrs_tracked = _make_closed_jaxpr_attrs(transposed_wrapped, trans_avals)
979+
return trans_jaxpr, attrs_tracked
944980

945981

946982
def _scan_batching_rule(axis_data, args,

0 commit comments

Comments
 (0)