Skip to content

Commit 1fb2ec4

Browse files
yashk2810 authored and Google-ML-Automation committed
Propagate unreduced for transpose and DUS correctly in shard_map
PiperOrigin-RevId: 836311442
1 parent 9504a47 commit 1fb2ec4

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

jax/_src/lax/lax.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7498,7 +7498,12 @@ def _transpose_sharding_rule(operand, *, permutation):
74987498
return operand.sharding.update(spec=o_spec.update(partitions=new_spec))
74997499

75007500
def _transpose_unreduced_rule(out_s, operand, *, permutation):
7501-
return out_s
7501+
return out_s.update(spec=out_s.spec.update(
7502+
unreduced=operand.sharding.spec.unreduced))
7503+
7504+
def _transpose_reduced_rule(out_s, operand, *, permutation):
7505+
return out_s.update(spec=out_s.spec.update(
7506+
reduced=operand.sharding.spec.reduced))
75027507

75037508
def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
75047509
operand, = batched_args
@@ -7524,7 +7529,8 @@ def _transpose_lower(ctx, x, *, permutation):
75247529
_transpose_shape_rule, input_dtype, 'transpose',
75257530
sharding_rule=_transpose_sharding_rule,
75267531
vma_rule=partial(core.standard_vma_rule, 'transpose'),
7527-
unreduced_rule=_transpose_unreduced_rule)
7532+
unreduced_rule=_transpose_unreduced_rule,
7533+
reduced_rule=_transpose_reduced_rule)
75287534
ad.deflinear2(transpose_p,
75297535
lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
75307536
batching.primitive_batchers[transpose_p] = _transpose_batch_rule

jax/_src/lax/slicing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1723,7 +1723,8 @@ def _dynamic_update_slice_unreduced_rule(out_s, operand, update, *start_indices)
17231723
" same axes. Got operand sharding"
17241724
f" {operand.str_short(mesh_axis_types=True)} and update sharding"
17251725
f" {update.str_short(mesh_axis_types=True)}.")
1726-
return out_s
1726+
return out_s.update(spec=out_s.spec.update(
1727+
unreduced=operand.sharding.spec.unreduced))
17271728

17281729
def _dynamic_update_slice_reduced_rule(out_s, operand, update, *start_indices):
17291730
if operand.sharding.spec.reduced != update.sharding.spec.reduced:

tests/shard_map_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4583,6 +4583,29 @@ def loss_fn(inputs, params, targets):
45834583

45844584
jax.jit(jax.grad(loss_fn, argnums=1))(inputs, params, targets) # doesn't crash
45854585

4586+
@jtu.with_explicit_mesh((2,), 'x')
4587+
def test_transpose_unreduced_shmap(self, mesh):
4588+
arr1 = jax.device_put(np.arange(8.).reshape(2, 4), P(reduced={'x'}))
4589+
arr2 = jax.device_put(np.arange(12.).reshape(2, 6), P(None, 'x'))
4590+
4591+
@jax.shard_map(out_specs=P(None, 'x'))
4592+
def f(x, y):
4593+
x_ = x.T
4594+
return jnp.dot(x_, y)
4595+
4596+
@jax.jit
4597+
def g(x, y):
4598+
return f(x, y).sum()
4599+
4600+
out = g(arr1, arr2)
4601+
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
4602+
self.assertArraysEqual(out, (arr1.T @ arr2).sum())
4603+
4604+
out1, out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr2)
4605+
self.assertEqual(out1.sharding,
4606+
NamedSharding(mesh, P(None, None, unreduced={'x'})))
4607+
self.assertEqual(out2.sharding, arr2.sharding)
4608+
45864609

45874610
class FunSpec(NamedTuple):
45884611
name: str

0 commit comments

Comments (0)