
Commit 6568713

yashk2810 authored and Google-ML-Automation committed
[sharding_in_types] Add concatenate_p support
PiperOrigin-RevId: 698621325
1 parent 869a533 commit 6568713

2 files changed: +59 additions, -3 deletions

jax/_src/lax/lax.py

Lines changed: 17 additions & 3 deletions
@@ -1700,6 +1700,8 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array:
     scalar_zero = np.zeros((), dtype=aval.dtype)
   else:
     scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
+  if config.sharding_in_types.value:
+    return broadcast(scalar_zero, aval.shape, sharding=aval.sharding)
   return broadcast(scalar_zero, aval.shape)
 
 ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
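Note: a minimal usage sketch of what this zeros_like change is aimed at, assuming two available devices and the sharding_in_types config enabled (the mesh setup and function names below are illustrative, not part of this commit): a symbolic zero instantiated during autodiff is now broadcast with the aval's sharding, so the gradient for an unused sharded operand is expected to keep that operand's sharding.

# Hypothetical usage sketch -- not taken from this commit.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
x = jax.device_put(np.arange(16.).reshape(4, 4), s)
y = jax.device_put(np.arange(16.).reshape(4, 4), s)

def h(x, y):
  return jnp.sum(x * x)   # y is unused, so its cotangent is a symbolic zero

# Instantiating that zero goes through zeros_like_shaped_array; with the
# change above it is broadcast with y's sharding, so under the flag the
# expectation is dy.sharding == s rather than an unspecified sharding.
dx, dy = jax.jit(jax.grad(h, argnums=(0, 1)))(x, y)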
@@ -4401,7 +4403,7 @@ def _concatenate_shape_rule(*operands, **kwargs):
     raise TypeError(msg.format(dimension, ", ".join([str(o.shape) for o in operands])))
   shapes = [operand.shape[:dimension] + operand.shape[dimension+1:]
             for operand in operands]
-  if not shapes[:-1] == shapes[1:]:
+  if shapes[:-1] != shapes[1:]:
     msg = ("Cannot concatenate arrays with shapes that differ in dimensions "
            "other than the one being concatenated: concatenating along "
            "dimension {} for shapes {}.")
@@ -4412,6 +4414,13 @@ def _concatenate_shape_rule(*operands, **kwargs):
   ex_shape = operands[0].shape
   return ex_shape[:dimension] + (concat_size,) + ex_shape[dimension+1:]
 
+def _concatenate_sharding_rule(*operands, **kwargs):
+  if not all(o.sharding == operands[0].sharding for o in operands):
+    ss = ", ".join(str(o.sharding) for o in operands)
+    raise TypeError(
+        f"All operands should have the same sharding. Got shardings {ss}")
+  return operands[0].sharding
+
 def _concatenate_dtype_rule(*operands, **kwargs):
   check_same_dtypes('concatenate', *operands)
   return operands[0].dtype
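In plain terms, the new rule requires every operand to carry the same sharding and propagates that common sharding to the output. A standalone restatement of the check (illustrative only: the function name is made up, and the real rule compares the shardings attached to the operand avals rather than bare PartitionSpecs):

from jax.sharding import PartitionSpec as P

def common_sharding(*shardings):
  # Every operand sharding must match; the shared one becomes the output's.
  if not all(s == shardings[0] for s in shardings):
    ss = ", ".join(str(s) for s in shardings)
    raise TypeError(
        f"All operands should have the same sharding. Got shardings {ss}")
  return shardings[0]

common_sharding(P('x', 'y'), P('x', 'y'))   # -> P('x', 'y')
# common_sharding(P('x', 'y'), P('x'))      # -> TypeError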
@@ -4452,14 +4461,19 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension):
   raise NotImplementedError  # TODO(mattjj)
 
 concatenate_p = standard_primitive(
-    _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate')
+    _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate',
+    sharding_rule=_concatenate_sharding_rule)
 ad.deflinear2(concatenate_p, _concatenate_transpose_rule)
 ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
 batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
 pe.padding_rules[concatenate_p] = _concatenate_pad_rule
 
 def _concatenate_lower(ctx, *xs, dimension):
-  return [hlo.concatenate(xs, mlir.i64_attr(dimension))]
+  aval_out, = ctx.avals_out
+  out = hlo.concatenate(xs, mlir.i64_attr(dimension))
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+  return [out]
 mlir.register_lowering(concatenate_p, _concatenate_lower)
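A sketch of the end-to-end behavior this diff enables, assuming two available devices and the sharding_in_types config enabled (the mesh construction here is illustrative; the test below uses jtu.create_mesh instead):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
arr1 = jax.device_put(np.arange(16.).reshape(4, 4), s)
arr2 = jax.device_put(np.arange(4.).reshape(4, 1), s)

@jax.jit
def f(x, y):
  # Under the flag, concatenate_p's sharding rule makes the traced output's
  # spec visible at trace time as P('x', 'y').
  return jnp.concatenate([x, y], axis=1)

out = f(arr1, arr2)   # expected: out.sharding == s

# Operands with mismatched shardings are rejected by the new sharding rule:
arr3 = jax.device_put(np.arange(4.).reshape(4, 1), NamedSharding(mesh, P('x')))
# f(arr1, arr3)       # expected: TypeError("All operands should have the same sharding ...")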

tests/pjit_test.py

Lines changed: 42 additions & 0 deletions
@@ -5384,6 +5384,48 @@ def g(x):
     arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y'))))
     f(arr, ((4, 4, 1),), None)
 
+  def test_concatenate(self):
+    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
+    np_inp = np.arange(16.).reshape(4, 4)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr1 = jax.device_put(np_inp, s)
+    arr2 = jax.device_put(np.arange(4.).reshape(4, 1), s)
+
+    @partial(jax.jit, static_argnums=2)
+    def f(x, y, method='jnp'):
+      if method == 'jnp':
+        y = jnp.concatenate([x, y], axis=1)
+      else:
+        assert method == 'lax'
+        y = lax.concatenate([x, y], dimension=1)
+      self.assertEqual(y.sharding.spec, P('x', 'y'))
+      return y
+
+    out = f(arr1, arr2)
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1))
+    self.assertIn('@Sharding', f.lower(arr1, arr2).as_text())
+
+    out = f(arr1, arr2, method='lax')
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1))
+
+    with self.assertRaisesRegex(
+        TypeError, "All operands should have the same sharding"):
+      arr3 = jax.device_put(np.arange(4.).reshape(4, 1),
+                            NamedSharding(mesh, P('x')))
+      f(arr1, arr3)
+
+    def g(x, y):
+      out = f(x, y)
+      return jnp.square(jnp.sum(out))
+
+    out = jax.grad(g)(arr1, arr2)
+    self.assertEqual(out.sharding, s)
+
+    out = jax.jit(jax.grad(g))(arr1, arr2)
+    self.assertEqual(out.sharding, s)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
