Skip to content

Commit 944d822

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Add a no-op batching rule for optimization_barrier_p
PiperOrigin-RevId: 704507586
1 parent 1743f2c commit 944d822

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

jax/_src/lax/lax.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6496,3 +6496,7 @@ def _optimization_barrier_lowering_rule(ctx, *args):
64966496
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
64976497
mlir.register_lowering(optimization_barrier_p,
64986498
_optimization_barrier_lowering_rule)
6499+
6500+
def _optimization_barrier_batcher(batched_args, batch_dims, **params):
6501+
return optimization_barrier_p.bind(*batched_args, **params), batch_dims
6502+
batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher

tests/lax_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3551,7 +3551,7 @@ def testAsarray(self, typ):
35513551
with jax.transfer_guard('disallow'):
35523552
jax.jit(asarray_closure)()
35533553

3554-
def testOptimizationBarrier(self):
3554+
def test_optimization_barrier(self):
35553555
x = lax.optimization_barrier((2, 3))
35563556
self.assertEqual((2, 3), x)
35573557

tests/lax_vmap_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,25 @@ def testTopK(self, shape, dtype, k, bdims):
691691
op2 = lambda x: lax.top_k(x, k=k)[1]
692692
self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)
693693

694+
@jtu.sample_product(
695+
[dict(shape=shape, bdims=bdims)
696+
for shape in [(8,), (3, 4, 5)]
697+
for bdims in lax_test_util.all_bdims(shape)],
698+
dtype=lax_test_util.default_dtypes,
699+
)
700+
def test_optimization_barrier_vmap(self, shape, dtype, bdims):
701+
rng = jtu.rand_small(self.rng())
702+
self._CheckBatching(lax.optimization_barrier, 5, bdims, (shape,), (dtype,),
703+
rng)
704+
705+
def test_optimization_barrier_vmap_out_axes(self):
706+
x = jnp.arange(8)
707+
y = x.reshape(1, 8)
708+
out = jax.vmap(lax.optimization_barrier, in_axes=((0, 1),),
709+
out_axes=(0, 1))((x, y))
710+
self.assertArraysEqual(out[0], x)
711+
self.assertArraysEqual(out[1], y)
712+
694713
@jtu.sample_product(
695714
[dict(shape=shape, bdims=bdims, dimension=dimension, arity=arity)
696715
for shape in [(2, 3)]

0 commit comments

Comments
 (0)