@@ -919,13 +919,7 @@ def _rearrange_mutable_binders(
919919 immut_names , mut_names = partition_list (is_mutable , names )
920920 new_arg_names = [* fst , * mut_names , * immut_names , * rst ]
921921 dbg = jaxpr .jaxpr .debug_info ._replace (arg_names = new_arg_names )
922-
923- # TODO(mattjj): don't we need to re-number effects? test coverage?
924- new_effs = pe ._renumber_effects ((* jaxpr .jaxpr .constvars , * new_invars ),
925- (* jaxpr .jaxpr .constvars , * jaxpr .jaxpr .invars ),
926- jaxpr .jaxpr .effects )
927- new_jaxpr = jaxpr .jaxpr .replace (invars = new_invars , effects = new_effs ,
928- debug_info = dbg )
922+ new_jaxpr = jaxpr .jaxpr .replace (invars = new_invars , debug_info = dbg )
929923 if config .enable_checks .value : core .check_jaxpr (new_jaxpr )
930924 return ClosedJaxpr (new_jaxpr , jaxpr .consts )
931925
@@ -1777,16 +1771,8 @@ def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts
17771771 ) -> effects .Effects :
17781772 joined_effects = set ()
17791773 for eff in cond_jaxpr .effects :
1780- if isinstance (eff , effects .JaxprInputEffect ):
1781- index = eff .input_index
1782- if index >= cond_nconsts :
1783- index += body_nconsts
1784- eff = eff .replace (input_index = index )
17851774 joined_effects .add (eff )
17861775 for eff in body_jaxpr .effects :
1787- if isinstance (eff , effects .JaxprInputEffect ):
1788- index = eff .input_index + cond_nconsts
1789- eff = eff .replace (input_index = index )
17901776 joined_effects .add (eff )
17911777 return joined_effects
17921778
@@ -2285,17 +2271,6 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args,
22852271 [cond_nconsts ,
22862272 body_nconsts ])
22872273
2288- # Check if the same Ref is written to in both cond and body.
2289- cond_write_ids = {id (cond_consts_avals [effect .input_index ])
2290- for effect in cond_jaxpr .effects if isinstance (effect , state .WriteEffect )}
2291- cond_has_writes = len (cond_write_ids ) > 0
2292- body_write_ids = {id (body_consts_avals [effect .input_index ])
2293- for effect in body_jaxpr .effects if isinstance (effect , state .WriteEffect )}
2294- write_to_both_ids = cond_write_ids & body_write_ids
2295- if write_to_both_ids :
2296- raise NotImplementedError (
2297- "Cannot write to the same ref in both cond and body of while loop." )
2298-
22992274 cond_is_ref = [
23002275 isinstance (aval , state .AbstractRef ) and should
23012276 for aval , should in zip (cond_consts_avals , cond_consts_discharge )
@@ -2315,13 +2290,6 @@ def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args,
23152290 num_body_refs = sum (body_is_ref )
23162291 num_remaining_body_consts = body_nconsts - num_body_refs
23172292 num_out_body_consts = num_remaining_body_consts
2318- if cond_has_writes :
2319- # If the cond has writes, we need to add the cond consts into the body
2320- # consts since we need to evaluate the cond condition in the body.
2321- remaining_body_consts = [* remaining_cond_consts , * remaining_body_consts ]
2322- remaining_body_const_avals = [* remaining_cond_const_avals ,
2323- * remaining_body_const_avals ]
2324- num_remaining_body_consts += num_remaining_cond_consts
23252293
23262294 num_carry = len (in_avals ) - body_nconsts - cond_nconsts
23272295 body_jaxpr , body_jaxpr_consts = body_jaxpr .jaxpr , body_jaxpr .consts
@@ -2353,23 +2321,8 @@ def new_body(*consts_refs_carry):
23532321 consts , body_refs , cond_refs , carry = split_list (
23542322 consts_refs_carry ,
23552323 [num_remaining_body_consts , num_body_refs , num_cond_refs ])
2356- if cond_has_writes :
2357- # We run the cond jaxpr in the body so that Refs that are updated
2358- # in the cond jaxpr are persisted via the carry.
2359- cond_consts , body_consts = split_list (consts , [num_remaining_cond_consts ])
2360- cond_consts_and_refs = merge_lists (cond_is_ref , cond_consts , cond_refs )
2361- cond_carry_refs = core .eval_jaxpr (discharged_cond_jaxpr , (),
2362- * cond_consts_and_refs ,
2363- * carry )
2364- # Note: in order to handle the same Ref being updated in both the cond
2365- # and body, we would need to interleave the updated cond_carry_refs into
2366- # body_refs here.
2367- # Currently we disallow this so we don't need to handle it.
2368- _ , cond_refs_out = split_list (cond_carry_refs , [1 ])
2369- assert len (cond_refs_out ) == len (cond_refs )
2370- else :
2371- body_consts = consts
2372- cond_refs_out = cond_refs
2324+ body_consts = consts
2325+ cond_refs_out = cond_refs
23732326
23742327 body_consts_and_refs = merge_lists (body_is_ref , body_consts , body_refs )
23752328 body_carry_refs = core .eval_jaxpr (discharged_body_jaxpr , (),
0 commit comments