@@ -1161,7 +1161,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
11611161 f'called with sequence whose items have type\n { _avals_short (x_avals_mapped )} ' )
11621162 return [* init_avals , * y_avals ], jaxpr .effects
11631163
1164- def _scan_state_discharge_rule ( in_avals , out_avals , * args , jaxpr , num_consts ,
1164+ def _scan_state_partial_discharge_rule ( should_discharge , in_avals , out_avals , * args , jaxpr , num_consts ,
11651165 num_carry , linear , unroll , reverse , length ,
11661166 _split_transpose ):
11671167 # We're shuffling parameters between three signatures for the scan body:
@@ -1182,39 +1182,59 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
11821182 n_ys = len (out_avals ) - n_carry
11831183 consts_avals , carry_avals , xs_avals = split_list_checked (in_avals ,
11841184 [n_consts , n_carry , n_xs ])
1185- is_ref_const = [isinstance (a , state .AbstractRef ) for a in consts_avals ]
1186- assert not any (isinstance (a , state .AbstractRef ) for a in carry_avals )
1187- is_ref_xs = [isinstance (a , state .AbstractRef ) for a in xs_avals ]
1185+ consts_discharge , carry_discharge , xs_discharge = split_list_checked (should_discharge ,
1186+ [n_consts , n_carry , n_xs ])
1187+
1188+ is_ref_const = [s and isinstance (a , state .AbstractRef ) for s , a in zip (consts_discharge , consts_avals )]
1189+ assert not any (isinstance (a , state .AbstractRef ) for a in carry_avals )
1190+ assert not any (carry_discharge )
1191+ is_ref_xs = [s and isinstance (a , state .AbstractRef ) for s , a in zip (xs_discharge , xs_avals )]
11881192 n_ref_consts = sum (is_ref_const )
11891193 n_val_consts = n_consts - n_ref_consts
11901194 n_ref_xs = sum (is_ref_xs )
11911195 n_val_xs = n_xs - n_ref_xs
1192- discharged_jaxpr , discharged_consts = state_discharge .discharge_state (jaxpr , ())
1196+ discharged_jaxpr , discharged_consts = state_discharge .discharge_state (jaxpr , (), should_discharge = should_discharge )
11931197 if discharged_consts :
11941198 raise NotImplementedError ("Discharged jaxpr has consts. If you see this, "
11951199 "please open an issue at "
11961200 "https://github.com/jax-ml/jax/issues" )
11971201 def wrapped (* wrapped_args ):
1198- val_consts , ref_consts_in , carry_in , val_xs , ref_xs_in = split_list_checked (wrapped_args ,
1199- [n_val_consts , n_ref_consts , n_carry , n_val_xs , n_ref_xs ])
1202+ val_consts , carry_in , ref_consts_in , val_xs , ref_xs_in = split_list_checked (wrapped_args ,
1203+ [n_val_consts , n_carry , n_ref_consts , n_val_xs , n_ref_xs ])
12001204 consts = merge_lists (is_ref_const , val_consts , ref_consts_in )
12011205 xs = merge_lists (is_ref_xs , val_xs , ref_xs_in )
12021206 outs = core .eval_jaxpr (discharged_jaxpr , (), * consts , * carry_in , * xs )
12031207 carry_out , ys , ref_consts_out , ref_xs_out = split_list_checked (outs ,
12041208 [n_carry , n_ys , n_ref_consts , n_ref_xs ])
1205- return [* ref_consts_out , * carry_out , * ys , * ref_xs_out ]
1209+ return [* carry_out , * ref_consts_out , * ys , * ref_xs_out ]
12061210
12071211 def arrange_jaxpr_args_for_wrapped (args ):
12081212 consts , carry_in , xs = split_list_checked (args , [n_consts , n_carry , n_xs ])
12091213 val_consts , ref_consts_in = partition_list (is_ref_const , consts )
12101214 val_xs , ref_xs_in = partition_list (is_ref_xs , xs )
1211- return * val_consts , * ref_consts_in , * carry_in , * val_xs , * ref_xs_in
1215+ return * val_consts , * carry_in , * ref_consts_in , * val_xs , * ref_xs_in
12121216
1217+ # Rearrange the arguments such that they are:
1218+ # val_consts, carry, ref_consts, val_xs, ref_xs
1219+ #
1220+ # It is important that carry is immediately after the val_consts
1221+ # because pallas pattern matches the leading argument type to figure
1222+ # out if a scan_p eqn is equivalent to a fori loop (see
1223+ # `pallas.utils.pattern_match_scan_to_fori_loop()`).
12131224 args_for_wrapped = arrange_jaxpr_args_for_wrapped (args )
12141225 linear_for_wrapped = arrange_jaxpr_args_for_wrapped (linear )
12151226 avals_for_wrapped = arrange_jaxpr_args_for_wrapped (in_avals )
1216- avals_for_wrapped_no_refs = [aval .inner_aval if isinstance (aval , state .AbstractRef ) else aval
1217- for aval in avals_for_wrapped ]
1227+ # Get the const avals that we need to discharge and leave the rest as-is.
1228+ deref_const_avals = tuple (c .inner_aval for c in avals_for_wrapped [n_val_consts + n_carry :n_consts + n_carry ])
1229+ deref_xs_avals = tuple (x .inner_aval for x in avals_for_wrapped [n_consts + n_carry + n_val_xs :])
1230+ avals_for_wrapped_no_refs = (
1231+ avals_for_wrapped [: n_val_consts + n_carry ]
1232+ + deref_const_avals
1233+ + avals_for_wrapped [n_consts + n_carry :n_consts + n_carry + n_val_xs ]
1234+ + deref_xs_avals
1235+ )
1236+ # TODO(cperivol): avoid tracing the jaxpr twice. When doing so don't
1237+ # forget to manage the effects.
12181238 new_jaxpr , _ , (), () = pe .trace_to_jaxpr_dynamic (lu .wrap_init (wrapped ), avals_for_wrapped_no_refs )
12191239 all_out = scan_p .bind (* args_for_wrapped ,
12201240 jaxpr = core .ClosedJaxpr (new_jaxpr , ()),
@@ -1224,8 +1244,8 @@ def arrange_jaxpr_args_for_wrapped(args):
12241244 unroll = unroll ,
12251245 reverse = reverse ,
12261246 linear = linear_for_wrapped , _split_transpose = _split_transpose )
1227- ref_consts_out , carry_out , ys , ref_xs_out = split_list_checked (all_out ,
1228- [n_ref_consts , n_carry , n_ys , n_ref_xs ])
1247+ carry_out , ref_consts_out , ys , ref_xs_out = split_list_checked (all_out ,
1248+ [n_carry , n_ref_consts , n_ys , n_ref_xs ])
12291249 refs_out_matching_in_avals = [
12301250 * merge_lists (is_ref_const , [None ] * n_val_consts , ref_consts_out ),
12311251 * [None ] * n_carry ,
@@ -1248,7 +1268,7 @@ def arrange_jaxpr_args_for_wrapped(args):
12481268pe .partial_eval_jaxpr_custom_rules [scan_p ] = _scan_partial_eval_custom
12491269pe .padding_rules [scan_p ] = _scan_padding_rule
12501270pe .dce_rules [scan_p ] = _scan_dce_rule
1251- state_discharge .register_discharge_rule (scan_p )(_scan_state_discharge_rule )
1271+ state_discharge .register_partial_discharge_rule (scan_p )(_scan_state_partial_discharge_rule )
12521272
12531273def _propagate_mem_kind_scan (* xm , reverse , length , num_consts , num_carry , jaxpr ,
12541274 linear , unroll , _split_transpose ):
0 commit comments