Skip to content

Commit c0811c9

Browse files
jkr26Google-ML-Automation
authored andcommitted
Adds coverage for spmd-axisname-filtering in shard_map transpose.
PiperOrigin-RevId: 699193349
1 parent 34a2f0c commit c0811c9

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

tests/shard_map_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,26 @@ def f(x):
709709
self.assertIn('out_names', e.params)
710710
self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},))
711711

712+
def test_vmap_of_grad_spmd_axis_name(self):
713+
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
714+
715+
@partial(
716+
shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_rep=False
717+
)
718+
def f(x):
719+
return jnp.sin(jnp.sum(x))
720+
721+
x = jnp.arange(4 * 4, dtype=jnp.float32).reshape(4, 4)
722+
put_x = jax.device_put(
723+
x,
724+
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')),
725+
)
726+
vmap_spmd_axisname_result = jax.vmap(jax.grad(f), spmd_axis_name='x')(put_x)
727+
vmap_no_spmd_axisname_result = jax.vmap(jax.grad(f))(put_x)
728+
self.assertArraysEqual(
729+
vmap_spmd_axisname_result, vmap_no_spmd_axisname_result
730+
)
731+
712732
def test_vmap_spmd_axis_name_pair(self):
713733
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
714734

0 commit comments

Comments
 (0)