Commit 9b2ebc0

yashk2810 authored and Google-ML-Automation committed
Fix reduce_sum_transpose_rule so that its broadcast_in_dim uses out_sharding=operand.aval.to_cotangent_aval().sharding instead of operand.aval.sharding. This is because if the operand is reduced, then on the backward pass we want the cotangent type to become unreduced.
PiperOrigin-RevId: 834511302
1 parent b37b6c0 commit 9b2ebc0

File tree

2 files changed: +19 -2 lines changed

jax/_src/lax/lax.py

Lines changed: 3 additions & 2 deletions
@@ -7824,8 +7824,9 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes, out_sharding):
   assert ad.is_undefined_primal(operand)
   input_shape = operand.aval.shape
   broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes))
-  result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions,
-                            out_sharding=operand.aval.sharding)
+  result = broadcast_in_dim(
+      cotangent, input_shape, broadcast_dimensions,
+      out_sharding=operand.aval.to_cotangent_aval().sharding)
   assert result.shape == input_shape
   return [result]
 
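For readers skimming the hunk above, the unchanged broadcast_dimensions line is the part that maps each surviving axis of the cotangent back to its position in the operand shape. A minimal sketch with made-up shapes (the names below are local to the sketch, not JAX internals):

import numpy as np

# Made-up example: a 3-D operand summed over axes (0, 2); only axis 1
# survives in the reduce_sum output, so the cotangent is broadcast back
# along that position.
input_shape = (4, 8, 2)   # stands in for operand.aval.shape
axes = (0, 2)             # stands in for the reduce_sum axes
broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes))
assert broadcast_dimensions == (1,)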

tests/pjit_test.py

Lines changed: 16 additions & 0 deletions
@@ -9631,6 +9631,22 @@ def test_jnp_repeat_arraylike(self, mesh):
     jnp.repeat(positions, 5, axis=0, total_repeat_length=num_electrons,
                out_sharding=P())  # doesn't crash
 
+  @jtu.with_explicit_mesh((2,), 'x')
+  def test_mul_inputs_both_reduced(self, mesh):
+    arr1 = jax.device_put(np.arange(8.), P(reduced={'x'}))
+    arr2 = jax.device_put(np.arange(8.), P(reduced={'x'}))
+
+    @jax.jit
+    def f(x, y):
+      z = x * y
+      return z.sum()
+
+    out1, out2 = jax.jit(jax.grad(f, argnums=(0, 1)))(arr1, arr2)
+    self.assertEqual(out1.sharding,
+                     NamedSharding(mesh, P(None, unreduced={'x'})))
+    self.assertEqual(out2.sharding,
+                     NamedSharding(mesh, P(None, unreduced={'x'})))
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
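For anyone who wants to reproduce this outside the test harness, here is a sketch under stated assumptions: a recent JAX build that has explicit mesh axis types and the reduced=/unreduced= PartitionSpec arguments used in the diff, plus at least two devices. The test's jtu.with_explicit_mesh decorator is test-only; jax.make_mesh and jax.sharding.use_mesh stand in for it here.

import jax
import numpy as np
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

# Two-device mesh with an explicitly typed 'x' axis (requires >= 2 devices).
mesh = jax.make_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))
with jax.sharding.use_mesh(mesh):
  # Both inputs carry a pending (reduced) sum over 'x'.
  arr1 = jax.device_put(np.arange(8.), P(reduced={'x'}))
  arr2 = jax.device_put(np.arange(8.), P(reduced={'x'}))

  def f(x, y):
    return (x * y).sum()

  g1, g2 = jax.jit(jax.grad(f, argnums=(0, 1)))(arr1, arr2)
  # With the fix, each cotangent comes back unreduced over 'x' rather than
  # inheriting the operands' reduced sharding.
  print(g1.sharding, g2.sharding)

The design point is the one from the commit message: transposition flips a reduced operand type into an unreduced cotangent type, which is exactly what operand.aval.to_cotangent_aval() encodes.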
