@@ -298,8 +298,17 @@ def _create_jaxpr(init):
298298 if len (out_tree_children ) != 2 :
299299 msg = "scan body output must be a pair, got {}."
300300 raise TypeError (msg .format (tree_unflatten (out_tree , jaxpr .out_avals )))
301- _ , carry_avals_out , _ = split_list (
302- jaxpr .out_avals , [len (attrs_tracked ), out_tree_children [0 ].num_leaves ])
301+
302+ if attrs_tracked :
303+ appends_out = [kind is pe .Append for * _ , (_ , _ , kind ) in attrs_tracked ]
304+ jaxpr = pe .move_outvars_to_back (
305+ jaxpr , appends_out + [False ] * (len (jaxpr .out_avals ) - len (appends_out )))
306+ num_attr_carry = sum (init_tree .num_leaves for init_tree , _ , (_ , _ , kind )
307+ in attrs_tracked if kind is pe .ReadWrite )
308+ _ , carry_avals_out , _ = split_list (
309+ jaxpr .out_avals , [num_attr_carry , out_tree_children [0 ].num_leaves ])
310+ else :
311+ carry_avals_out , _ = split_list (jaxpr .out_avals , [out_tree_children [0 ].num_leaves ])
303312 return (init_flat , carry_avals , carry_avals_out , init_tree , in_flat , jaxpr ,
304313 consts , out_tree , out_tree_children , attrs_tracked )
305314
@@ -332,37 +341,59 @@ def _create_jaxpr(init):
332341 raise ValueError ("`unroll` must be a `bool` or a positive `int`." )
333342 if attrs_tracked :
334343 in_state = _get_states (attrs_tracked )
335- in_carry , in_ext = split_list (in_flat , [num_carry ])
336- in_flat = [* in_state , * in_carry , * in_ext ]
337- num_carry += len (attrs_tracked )
344+ in_flat = [* in_state , * in_flat ]
345+ num_carry += len (in_state )
338346 out = scan_p .bind (* consts , * in_flat ,
339347 reverse = reverse , length = length , jaxpr = jaxpr ,
340348 num_consts = len (consts ), num_carry = num_carry ,
341349 linear = (False ,) * (len (consts ) + len (in_flat )),
342350 unroll = unroll ,
343351 _split_transpose = _split_transpose )
344352 if attrs_tracked :
345- out_state , out = split_list (out , [len (attrs_tracked )])
346- _set_states (attrs_tracked , out_state )
353+ num_ext = (len (out ) - len (in_state )
354+ - sum (k is pe .Append for * _ , (_ , _ , k ) in attrs_tracked ))
355+ out_state , out , out_append = split_list (out , [len (in_state ), num_ext ])
356+ out_attrs = _merge_attrs_out (attrs_tracked , out_state , out_append )
357+ _set_states (attrs_tracked , out_attrs )
347358 return tree_unflatten (out_tree , out )
348359
349360def _set_states (attrs_tracked , vals ):
350- from jax .experimental .attrs import jax_setattr
361+ from jax .experimental .attrs import jax_setattr , jax_extendattr
351362 valss = split_list_checked (vals , [td .num_leaves for _ , td , _ in attrs_tracked ])
352- for ((_ , treedef , (obj , attr )), leaves ) in zip (attrs_tracked , valss ):
353- val = tree_unflatten (treedef , leaves )
354- jax_setattr (obj , attr , val )
363+ for ((_ , treedef , (obj , attr , kind )), leaves ) in zip (attrs_tracked , valss ):
364+ if kind is pe .ReadWrite :
365+ val = tree_unflatten (treedef , leaves )
366+ jax_setattr (obj , attr , val )
367+ elif kind is pe .Append :
368+ val , = leaves
369+ jax_extendattr (obj , attr , val .reshape (- 1 , * val .shape [2 :]))
370+ else :
371+ assert False
355372
356373def _get_states (attrs_tracked ):
357374 from jax .experimental .attrs import jax_getattr
358375 vals = []
359- for treedef , _ , (obj , attr ) in attrs_tracked :
360- tree = jax_getattr (obj , attr )
361- leaves , treedef_ = tree_flatten (tree )
362- assert treedef == treedef_
363- vals .extend (leaves )
376+ for treedef , _ , (obj , attr , kind ) in attrs_tracked :
377+ if kind is pe .ReadWrite :
378+ tree = jax_getattr (obj , attr )
379+ leaves , treedef_ = tree_flatten (tree )
380+ assert treedef == treedef_
381+ vals .extend (leaves )
382+ elif kind is pe .Append :
383+ pass
384+ else :
385+ assert False
364386 return vals
365387
388+ def _merge_attrs_out (attrs_tracked , out_state , out_append ):
389+ out_state_ , out_append_ = iter (out_state ), iter (out_append )
390+ out_attrs = [item for _ , out_tree , (_ , _ , k ) in attrs_tracked for item in
391+ (itertools .islice (out_state_ , out_tree .num_leaves )
392+ if k is pe .ReadWrite else [next (out_append_ )])]
393+ assert next (out_state_ , None ) is next (out_append_ , None ) is None
394+ return out_attrs
395+
396+
366397def _capitalize (s ):
367398 # s.capitalize() converts s[1:] to lowercase which we don't want.
368399 return s [0 ].capitalize () + s [1 :]
@@ -662,7 +693,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
662693 # The above trace_to_jaxpr_nounits call computed loop-invariant residuals
663694 # (known values in invar_pvals_out) and also computed loop-invariant values
664695 # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the
665- # previous consts). We need to collect the computed inteisive residuals, and
696+ # previous consts). We need to collect the computed intensive residuals, and
666697 # move corresponding intensive residual binders in jaxpr_unknown to the front.
667698 res_pvals = invar_pvals_out [len (invar_pvals_out ) - num_res :]
668699 intensive_res = [pval .get_known () for pval in res_pvals if pval .is_known ()]
@@ -785,16 +816,21 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
785816 ct_consts = _map (ad_util .zeros_like_aval , jaxpr .in_avals [num_ires :num_consts ])
786817
787818 # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
788- # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
819+ # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a, e ])
789820 jaxpr_trans , attrs_tracked = _transpose_scan_jaxpr (
790821 jaxpr , num_ires , num_consts - num_ires , num_eres , ct_ys_is_zeros )
791- linear_trans = ([False ] * num_ires + [False ] * len (attrs_tracked ) +
822+ appends_out = [kind is pe .Append for * _ , (_ , _ , kind ) in attrs_tracked ]
823+ jaxpr_trans = pe .move_outvars_to_back (
824+ jaxpr_trans , appends_out + [False ] * (len (jaxpr_trans .out_avals ) - len (appends_out )))
825+ num_attr_carry = sum (init_tree .num_leaves for init_tree , _ , (_ , _ , kind )
826+ in attrs_tracked if kind is pe .ReadWrite )
827+ linear_trans = ([False ] * num_ires + [False ] * num_attr_carry +
792828 [True ] * (len (ct_consts ) + len (ct_carry ) + len (ct_ys )) +
793829 [False ] * num_eres )
794830 in_state = _get_states (attrs_tracked )
795831
796832 transpose_inputs = * ires , * in_state , * ct_consts , * ct_carry , * ct_ys , * eres
797- transpose_num_out_carry = num_consts - num_ires + num_carry + len ( attrs_tracked )
833+ transpose_num_out_carry = num_consts - num_ires + num_carry + num_attr_carry
798834
799835 if not _split_transpose :
800836 outs = scan_p .bind (
@@ -889,8 +925,10 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
889925 for mask in outs_mask
890926 ]
891927
892- out_state , outs = split_list (outs , [len (attrs_tracked )])
893- _set_states (attrs_tracked , out_state )
928+ num_outs = len (outs ) - num_attr_carry - sum (appends_out )
929+ out_state , outs , out_append = split_list (outs , [num_attr_carry , num_outs ])
930+ out_attrs = _merge_attrs_out (attrs_tracked , out_state , out_append )
931+ _set_states (attrs_tracked , out_attrs )
894932 ct_consts , ct_init , ct_xs = split_list (outs , [num_consts - num_ires , num_carry ])
895933 return [None ] * num_ires + ct_consts + ct_init + ct_xs + [None ] * num_eres
896934
@@ -935,12 +973,10 @@ def transposed(*res1_cbar_bbar_res2):
935973 return c_bar + a_bar
936974
937975 # TODO(necula): fix arg names and results for transposed
938- transposed_wrapped = lu .wrap_init (transposed ,
939- debug_info = jaxpr .jaxpr .debug_info )
940- return _make_closed_jaxpr_attrs (
941- transposed_wrapped ,
942- tuple (res1_avals + c_avals + b_carry_avals +
943- b_ys_avals_stripped + res2_avals ))
976+ transposed_wrapped = lu .wrap_init (transposed , debug_info = jaxpr .jaxpr .debug_info )
977+ trans_avals = (* res1_avals , * c_avals , * b_carry_avals , * b_ys_avals_stripped , * res2_avals )
978+ trans_jaxpr , attrs_tracked = _make_closed_jaxpr_attrs (transposed_wrapped , trans_avals )
979+ return trans_jaxpr , attrs_tracked
944980
945981
946982def _scan_batching_rule (axis_data , args ,
0 commit comments