
Commit 9b94180

yashk2810 authored and Google-ML-Automation committed
[sharding_in_types] Add slice_p and squeeze_p sharding rules to make flash attention work in the backward pass
For `slice_p`'s sharding rule, I error out if the operand dim is sharded and the output dim is not divisible by that axis size. I am working on a design to make JAX support uneven sharding at the top level, after which `slice_p`'s sharding rule can simply `return operand.sharding`. Another option is to add `out_sharding` to `slice`, but it won't be necessary once uneven sharding support lands.

PiperOrigin-RevId: 698522980
1 parent d219439 commit 9b94180
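
As a quick orientation (not part of the commit), the sketch below mirrors the tests added in tests/pjit_test.py: the mesh shape and axis names are illustrative, it assumes at least four devices, and it assumes the experimental sharding-in-types mode (`config.sharding_in_types`) is enabled.

import jax
import numpy as np
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 2x2 mesh; requires at least 4 devices.
mesh = Mesh(np.asarray(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jax.device_put(np.arange(16).reshape(4, 4),
                   NamedSharding(mesh, P('x', None)))

@jax.jit
def f(x):
  # With this commit, slice_p (and likewise squeeze_p) propagates the
  # operand's sharding instead of failing the sharding-in-types checks.
  return lax.slice(x, (0, 0), (4, 3))

out = f(x)  # out.sharding is NamedSharding(mesh, P('x', None))

Slicing a dimension that is itself sharded still raises NotImplementedError when the sliced size is not divisible by the number of shards on that dimension, as described above.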

File tree

4 files changed: +84 −5 lines

jax/_src/lax/lax.py
jax/_src/lax/slicing.py
jax/_src/pallas/core.py
tests/pjit_test.py

jax/_src/lax/lax.py

Lines changed: 12 additions & 2 deletions
@@ -4527,6 +4527,12 @@ def _squeeze_dtype_rule(operand, *, dimensions):
 def _squeeze_shape_rule(operand, *, dimensions):
   return _compute_squeeze_shape(np.shape(operand), dimensions)

+def _squeeze_sharding_rule(operand, *, dimensions):
+  dims_set = set(dimensions)
+  new_spec = tuple(s for i, s in enumerate(operand.sharding.spec)
+                   if i not in dims_set)
+  return NamedSharding(operand.sharding.mesh, P(*new_spec))
+
 def _compute_squeeze_shape(shape, dimensions):
   dims_set = set(dimensions)
   if len(dims_set) != len(dimensions):

@@ -4555,15 +4561,19 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
   return squeeze(operand, dimensions=dimensions), bdim_out

 squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
-                               'squeeze')
+                               'squeeze', sharding_rule=_squeeze_sharding_rule)
 ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
 batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
 pe.def_trivial_padding(squeeze_p)
 batching.ragged_prop_rules[squeeze_p] = batching.ragged_mask_no_op_rule

 def _squeeze_lower(ctx, operand, *, dimensions):
   del dimensions  # Implied by the output aval.
-  return [mlir.reshape(ctx, operand, ctx.avals_out[0])]
+  aval_out, = ctx.avals_out
+  out = mlir.reshape(ctx, operand, aval_out)
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+  return [out]

 mlir.register_lowering(squeeze_p, _squeeze_lower)
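
For reference, a minimal standalone sketch (not part of the commit) of what `_squeeze_sharding_rule` computes: the output spec is the operand spec with the entries for the squeezed dimensions dropped.

# Illustration only: mimic the spec computation on plain tuples, without
# building a real NamedSharding.
def squeezed_spec(spec, dimensions):
  dims_set = set(dimensions)
  return tuple(s for i, s in enumerate(spec) if i not in dims_set)

assert squeezed_spec(('x', None, None), (2,)) == ('x', None)  # as in test_squeeze below
assert squeezed_spec((None, 'y', None), (0, 2)) == ('y',)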

jax/_src/lax/slicing.py

Lines changed: 31 additions & 3 deletions
@@ -42,6 +42,7 @@
     _input_dtype,
     standard_primitive,
 )
+from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.typing import Array, ArrayLike, Shape

@@ -1270,6 +1271,29 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
   return tuple(core.stride_dim(d, window_size=1, window_stride=s)
                for d, s in zip(diff, strides))

+def _get_sub_spec_size(mesh, sub_spec):
+  if isinstance(sub_spec, tuple):
+    return math.prod(mesh.shape[s] for s in sub_spec)
+  return mesh.shape[sub_spec]
+
+def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides):
+  # TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
+  # change this logic to `return operand.sharding` directly.
+  out_shape = _slice_shape_rule(operand, start_indices=start_indices,
+                                limit_indices=limit_indices, strides=strides)
+  mesh = operand.sharding.mesh
+  new_spec = []
+  for op_sh, out_sh, op_spec in safe_zip(
+      operand.shape, out_shape, operand.sharding.spec):
+    if (op_sh != out_sh and op_spec is not None and
+        out_sh % _get_sub_spec_size(mesh, op_spec) != 0):
+      raise NotImplementedError(
+          f"slicing on sharded dims where out dim ({out_sh}) is not divisible by"
+          f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec"
+          f" ({op_spec}) is not implemented.")
+    new_spec.append(op_spec)
+  return NamedSharding(mesh, P(*new_spec))
+
 def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides):
   assert ad.is_undefined_primal(operand)
   operand_shape = operand.aval.shape

@@ -1308,7 +1332,8 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices,
   out = slice(operand, new_start_indices, new_limit_indices, new_strides)
   return out, bdim

-slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice')
+slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice',
+                             sharding_rule=_slice_sharding_rule)
 ad.deflinear2(slice_p, _slice_transpose_rule)
 batching.primitive_batchers[slice_p] = _slice_batching_rule
 # TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries

@@ -1333,8 +1358,11 @@ def _slice_impl(x, start_indices, limit_indices, strides):
 def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
   strides = strides or [1] * len(start_indices)
   aval_out, = ctx.avals_out
-  return [mlir.slice_op(ctx, x, aval_out,
-                        start_indices=start_indices, limit_indices=limit_indices, strides=strides)]
+  out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices,
+                      limit_indices=limit_indices, strides=strides)
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+  return [out]

 mlir.register_lowering(slice_p, _slice_lower)
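
To make the divisibility check in `_slice_sharding_rule` concrete, here is a small standalone sketch (hypothetical helper names, not part of the commit) of the same logic on plain Python values, using the shapes from the tests below: slicing a (4, 4) array down to (4, 3) is fine while the second dimension is unsharded, but errors once that dimension is sharded over an axis of size 2, since 3 % 2 != 0.

import math

# Hypothetical standalone version of the check; mesh_axis_sizes plays the
# role of mesh.shape in the real rule.
def check_slice_sharding(op_shape, out_shape, spec, mesh_axis_sizes):
  for op_dim, out_dim, axes in zip(op_shape, out_shape, spec):
    if op_dim == out_dim or axes is None:
      continue  # unsliced or unsharded dims are always fine
    axes = axes if isinstance(axes, tuple) else (axes,)
    shard_count = math.prod(mesh_axis_sizes[a] for a in axes)
    if out_dim % shard_count != 0:
      raise NotImplementedError(
          f"slicing on sharded dims where out dim ({out_dim}) is not "
          f"divisible by mesh axes ({shard_count}) is not implemented.")

sizes = {'x': 2, 'y': 2}
check_slice_sharding((4, 4), (4, 3), ('x', None), sizes)   # OK: dim 1 is unsharded
# check_slice_sharding((4, 4), (4, 3), ('x', 'y'), sizes)  # raises: 3 % 2 != 0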

jax/_src/pallas/core.py

Lines changed: 4 additions & 0 deletions
@@ -219,6 +219,10 @@ def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any):
   def __repr__(self) -> str:
     return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}'

+  @property
+  def sharding(self):
+    return self.inner_aval.sharding
+
   def update_weak_type(self, weak_type):
     return AbstractMemoryRef(
         self.inner_aval.update_weak_type(weak_type), self.memory_space)

tests/pjit_test.py

Lines changed: 37 additions & 0 deletions
@@ -5285,6 +5285,43 @@ def f(x, y):
     self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))

+  def test_slice(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(4, 4)
+    arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
+
+    @jax.jit
+    def f(x):
+      y = lax.slice(x, (0, 0), (4, 3))
+      self.assertEqual(y.sharding.spec, P('x', None))
+      return y
+
+    out = f(arr)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+    self.assertIn('@Sharding', f.lower(arr).as_text())
+
+    with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"):
+      f(jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))))
+
+    with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"):
+      f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y')))))
+
+  def test_squeeze(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(4, 4, 1)
+    arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None)))
+
+    @jax.jit
+    def f(x):
+      y = lax.squeeze(x, (2,))
+      self.assertEqual(y.sharding.spec, P('x', None))
+      return y
+
+    out = f(arr)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.assertArraysEqual(out, np.squeeze(np_inp, axis=2))
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
