@@ -639,6 +639,26 @@ def f(a_ref):
639639 refval , = core .eval_jaxpr (discharged_jaxpr , discharged_consts , inval )
640640 self .assertTrue ((refval == inval .at [jnp .array ([0 , 1 ])].set (1. )).all ())
641641
642+ def test_discharge_swap (self ):
643+ def f (a_ref ):
644+ a = ref_swap (
645+ a_ref .at [0 :4 , 0 :3 , 0 :2 ].at [1 :3 , :, 0 ],
646+ (slice (None ), slice (1 , 3 )),
647+ jnp .zeros ((2 , 2 ), jnp .float32 ))
648+ return [a + 1 ]
649+ in_avals = [shaped_array_ref ((4 , 3 , 2 ), jnp .float32 )]
650+ stateful_jaxpr , _ , consts , () = pe .trace_to_jaxpr_dynamic (
651+ lu .wrap_init (f ), in_avals )
652+
653+ discharged_jaxpr , () = discharge_state (stateful_jaxpr , consts )
654+ self .assertLen (discharged_jaxpr .invars , 1 )
655+ self .assertLen (discharged_jaxpr .outvars , 2 )
656+
657+ inval = jnp .arange (24. , dtype = jnp .float32 ).reshape ((4 , 3 , 2 ))
658+ outval , refval = core .eval_jaxpr (discharged_jaxpr , (), inval )
659+ self .assertArraysEqual (outval , inval [1 :3 , 1 :3 , 0 ] + 1 )
660+ self .assertArraysEqual (refval , inval .at [1 :3 , 1 :3 , 0 ].set (0 ))
661+
642662 def test_discharge_addupdate (self ):
643663 def f (a_ref , b ):
644664 ref_addupdate (a_ref , (), b + 1 )
0 commit comments