Skip to content

Commit 72df8e0

Browse files
Merge pull request jax-ml#25205 from jburnim:jburnim_swap_fix
PiperOrigin-RevId: 703541711
2 parents eda7506 + af50135 commit 72df8e0

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

jax/_src/state/discharge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def transform_swap_array(x, transforms, val):
364364
case indexing.NDIndexer():
365365
indexer = transform
366366
if _is_trivial_indexer(indexer):
367-
_results.append(None)
367+
_results.append(_results[-1])
368368
continue
369369
# If everything in the indexer is a slice or ()-shaped, we can also
370370
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.

tests/state_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,26 @@ def f(a_ref):
639639
refval, = core.eval_jaxpr(discharged_jaxpr, discharged_consts, inval)
640640
self.assertTrue((refval == inval.at[jnp.array([0, 1])].set(1.)).all())
641641

642+
def test_discharge_swap(self):
643+
def f(a_ref):
644+
a = ref_swap(
645+
a_ref.at[0:4, 0:3, 0:2].at[1:3, :, 0],
646+
(slice(None), slice(1, 3)),
647+
jnp.zeros((2, 2), jnp.float32))
648+
return [a + 1]
649+
in_avals = [shaped_array_ref((4, 3, 2), jnp.float32)]
650+
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
651+
lu.wrap_init(f), in_avals)
652+
653+
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
654+
self.assertLen(discharged_jaxpr.invars, 1)
655+
self.assertLen(discharged_jaxpr.outvars, 2)
656+
657+
inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2))
658+
outval, refval = core.eval_jaxpr(discharged_jaxpr, (), inval)
659+
self.assertArraysEqual(outval, inval[1:3, 1:3, 0] + 1)
660+
self.assertArraysEqual(refval, inval.at[1:3, 1:3, 0].set(0))
661+
642662
def test_discharge_addupdate(self):
643663
def f(a_ref, b):
644664
ref_addupdate(a_ref, (), b + 1)

0 commit comments

Comments
 (0)