Commit 80784a5

Merge pull request jax-ml#26387 from Rifur13:sharding
PiperOrigin-RevId: 738919611
2 parents: ea7fa29 + 412f1d3

6 files changed (+351, -40 lines)

jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py

Lines changed: 24 additions & 1 deletion
@@ -2293,6 +2293,26 @@ def _splash_attention(
     mask_function: MaskFunctionType | None,
     interpret: bool,
 ) -> SplashCustomReturnType:
+  """
+  For dynamic masks, `partial_mask_blocks` has shape (head_count, q_blocks, kv_blocks, block_q, block_kv).
+  This shape allows sharding across both head count and query sequence dimensions.
+
+  Note: The leading dimensions (head_count, q_blocks, kv_blocks) must be
+  collapsed into a single dimension before being passed to the kernel.
+  """
+  def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None):
+    if mask_info is None or mask_info.partial_mask_blocks is None:
+      return mask_info
+
+    return mask_info._replace(
+        partial_mask_blocks=mask_info.partial_mask_blocks.reshape(
+            -1, *mask_info.partial_mask_blocks.shape[-2:]
+        )
+    )
+
+  fwd_mask_info = _collapse_partial_mask_blocks(fwd_mask_info)
+  dq_mask_info = _collapse_partial_mask_blocks(dq_mask_info)
+  dkv_mask_info = _collapse_partial_mask_blocks(dkv_mask_info)
   return _splash_attention_custom(
       fwd_mask_info,
       dq_mask_info,
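The collapse in the hunk above only folds the three leading block-index dimensions together; the per-block tile shape is untouched. A minimal sketch of the same reshape, with hypothetical dimensions chosen purely for illustration:

import jax.numpy as jnp

# Hypothetical dynamic-mask dimensions, for illustration only.
head_count, q_blocks, kv_blocks, block_q, block_kv = 2, 4, 4, 128, 128
partial_mask_blocks = jnp.zeros(
    (head_count, q_blocks, kv_blocks, block_q, block_kv), dtype=jnp.bool_
)

# Same reshape as _collapse_partial_mask_blocks: flatten all leading
# dimensions, keep the trailing (block_q, block_kv) tile shape.
collapsed = partial_mask_blocks.reshape(-1, *partial_mask_blocks.shape[-2:])
assert collapsed.shape == (head_count * q_blocks * kv_blocks, block_q, block_kv)

Keeping the uncollapsed layout until this point is what lets the mask info be sharded over heads and query blocks before the kernel sees a flat list of blocks.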
@@ -2352,13 +2372,16 @@ def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding):
     spec = sharding.spec
     assert len(spec) == 2
     replicated = jax.sharding.PartitionSpec()
+    partial_mask_blocks_spec = (
+        spec if self.fwd_mask_info.is_dynamic_mask else replicated
+    )
     # Shard q_sequence over the sequence dimension only.
     q_sequence_spec = jax.sharding.PartitionSpec(spec[1])
     mask_info_specs = mask_info_lib.MaskInfo(  # pytype: disable=wrong-arg-types
         data_next=spec if self.fwd_mask_info.data_next is not None else None,
         mask_next=spec if self.fwd_mask_info.mask_next is not None else None,
         block_mask=spec if self.fwd_mask_info.block_mask is not None else None,
-        partial_mask_blocks=replicated
+        partial_mask_blocks=partial_mask_blocks_spec
         if self.fwd_mask_info.partial_mask_blocks is not None
         else None,
         q_sequence=q_sequence_spec
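Here `spec` is asserted to have two entries (a head axis and a sequence axis). Applied to the rank-5 uncollapsed `partial_mask_blocks`, a 2-element PartitionSpec shards only the two leading dimensions (head_count and q_blocks) and leaves the rest replicated. A rough sketch of that behaviour, assuming a hypothetical 2x2 mesh with axes ('heads', 'q_seq') and at least four local devices (everything below is illustrative, not code from this change):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical mesh; requires >= 4 devices (e.g. 4 TPU chips).
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('heads', 'q_seq'))

spec = PartitionSpec('heads', 'q_seq')
replicated = PartitionSpec()
is_dynamic_mask = True  # stand-in for self.fwd_mask_info.is_dynamic_mask
partial_mask_blocks_spec = spec if is_dynamic_mask else replicated

# Rank-5 dynamic-mask blocks: the 2-element spec shards axes 0 and 1 and
# replicates the trailing (kv_blocks, block_q, block_kv) axes.
blocks = jnp.zeros((8, 4, 4, 128, 128), dtype=jnp.bool_)
blocks = jax.device_put(blocks, NamedSharding(mesh, partial_mask_blocks_spec))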

jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py

Lines changed: 83 additions & 27 deletions
@@ -67,13 +67,18 @@ class MaskInfo(NamedTuple):
     q_sequence: A i32[q_sequence_length] NumPy array. When using causal masking,
       this contains the list of indices that correspond to q tokens. For plain
       causal this is just np.arange(q_sequence_length).
+    is_dynamic_mask: A bool indicating whether the mask is dynamic or static.
+      When True, the leading dimensions of `partial_mask_blocks` (num_heads,
+      q_blocks, kv_blocks) are not collapsed, allowing us to shard it along
+      those dimensions.
   """

   data_next: np.ndarray | jax.Array | None
   mask_next: np.ndarray | jax.Array | None
   block_mask: np.ndarray | jax.Array | None
   partial_mask_blocks: np.ndarray | jax.Array | None
   q_sequence: np.ndarray | None
+  is_dynamic_mask: bool = None


 def _downcast_to_small_type(array: np.ndarray) -> np.ndarray:
@@ -168,7 +173,7 @@ def __eq__(self, other: object) -> bool:
 def _get_mask_info_for_shard(
     output_shape: tuple[int, int, int],
     has_mask_next: bool,
-    mask: mask_lib.MultiHeadMask,
+    mask: mask_lib.MultiHeadMask | jax.Array,
     block_shape: tuple[int, int],
     coords_to_partial_mask_block_index: dict[tuple[int, int, int], int],
     masks_per_head_shard: int,
@@ -338,7 +343,8 @@ def _process_dynamic_mask(
       launched.
     q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is
       launched.
-    shrink_grid: Whether or not we should apply the grid shrinking optimization. This is currently ignored.
+    shrink_grid: Whether or not we should apply the grid shrinking optimization.
+      This is currently ignored.

   Returns:
     `MaskInfo`, a sparse representation of the dense mask.
@@ -349,11 +355,6 @@ def _process_dynamic_mask(
   """

   del shrink_grid
-
-  # TODO(pobudzey): Properly support sharding.
-  if head_shards != 1 or q_seq_shards != 1:
-    raise ValueError('Dynamic mask processing does not support sharding.')
-
   if len(mask.shape) != 3:
     raise ValueError(f'Expected a 3-dim mask, instead got: {mask.shape}.')

@@ -370,6 +371,18 @@ def _process_dynamic_mask(
   if kv_mod != 0:
     raise ValueError(f'{kv_block_size=} should divide {kv_seq_len=}.')

+  q_seq_len_per_shard, mod = divmod(q_seq_len, q_seq_shards)
+  if mod != 0:
+    raise ValueError(f'{q_seq_shards=} should divide {q_seq_len=}.')
+
+  q_blocks_per_shard, mod = divmod(q_seq_len_per_shard, q_block_size)
+  if mod != 0:
+    raise ValueError(f'{q_block_size=} should divide {q_seq_len_per_shard=}.')
+
+  heads_per_shard, mod = divmod(head_count, head_shards)
+  if mod != 0:
+    raise ValueError(f'{head_shards=} should divide {head_count=}.')
+
   block_mask_shape = (
       head_count,
       q_blocks_count,
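The three checks added above just pin down a consistent shard geometry. A tiny worked example with hypothetical sizes (not taken from the change itself):

# Hypothetical shard geometry, chosen only to illustrate the checks.
q_seq_len, q_seq_shards, q_block_size = 4096, 2, 512
head_count, head_shards = 8, 2

q_seq_len_per_shard, mod = divmod(q_seq_len, q_seq_shards)           # 2048, 0
assert mod == 0
q_blocks_per_shard, mod = divmod(q_seq_len_per_shard, q_block_size)  # 4, 0
assert mod == 0
heads_per_shard, mod = divmod(head_count, head_shards)               # 4, 0
assert mod == 0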
@@ -398,26 +411,66 @@ def _process_dynamic_mask(
   block_mask = jnp.where(is_full_mask, 2, block_mask)
   block_mask = jnp.where(is_empty_mask, 0, block_mask)

-  # TODO(pobudzey): Return the next valid mask index instead of 0 for a more efficient pipeline.
-  mask_next = jnp.where(
-      jnp.logical_or(is_empty_mask, is_full_mask),
-      0,
-      jnp.arange(math.prod(block_mask_shape), dtype=np.int32).reshape(
-          block_mask_shape
-      ),
-  )
+  q_sequence_axis = 1
+  head_axis = 0

-  # data_next stores the index of the next non-empty data block in the sequence.
-  # The indices of empty blocks are set to 0 to avoid copying extra data when
-  # pipeling.
-  if is_dkv:
-    data_next = jnp.arange(q_blocks_count, dtype=np.int32)[None, :, None]
-  else:
-    data_next = jnp.arange(kv_blocks_count, dtype=np.int32)[None, None, :]
-  data_next = jnp.broadcast_to(data_next, block_mask_shape)
-  data_next = jnp.where(is_empty_mask, 0, data_next)
+  # Each iteration of the loop processes a slice of the mask info tensors of
+  # this shape:
+  mask_info_slice_shape = (heads_per_shard, q_blocks_per_shard, kv_blocks_count)
+
+  # Collect mask_info shards along the head dimension, concatenate (or
+  # broadcast) them after the loop.
+  data_next_per_head_list, mask_next_per_head_list = [], []
+  for head_shard in range(head_shards):
+    head_start = head_shard * heads_per_shard
+    mask_head_slice = slice(head_start, head_start + heads_per_shard)
+
+    # Collect mask_info shards along the q_sequence dimension, concatenate
+    # them after the loop.
+    data_next_sequence_slices, mask_next_sequence_slices = [], []
+    for q_seq_len_shard in range(q_seq_shards):
+      q_seq_len_start = q_seq_len_shard * q_blocks_per_shard
+      blocked_q_seq_len_slice = slice(
+          q_seq_len_start, q_seq_len_start + q_blocks_per_shard
+      )
+      local_block_mask = block_mask[mask_head_slice, blocked_q_seq_len_slice]
+
+      mask_next_slice = jnp.arange(
+          math.prod(mask_info_slice_shape), dtype=np.int32
+      ).reshape(mask_info_slice_shape)
+      mask_next_slice = jnp.where(local_block_mask == 1, mask_next_slice, 0)
+
+      # data_next stores the index of the next non-empty data block in the
+      # sequence. The indices of empty blocks are set to 0 to avoid copying
+      # extra data when pipelining.
+      if is_dkv:
+        data_next_slice = jnp.arange(q_blocks_per_shard, dtype=np.int32)[
+            None, :, None
+        ]
+      else:
+        data_next_slice = jnp.arange(kv_blocks_count, dtype=np.int32)[
+            None, None, :
+        ]
+      data_next_slice = jnp.broadcast_to(data_next_slice, mask_info_slice_shape)
+      data_next_slice = jnp.where(local_block_mask == 0, 0, data_next_slice)
+
+      data_next_sequence_slices.append(data_next_slice)
+      mask_next_sequence_slices.append(mask_next_slice)
+
+    # Concatenate the sequence shards.
+    data_next_per_head = jnp.concatenate(
+        data_next_sequence_slices, axis=q_sequence_axis
+    )
+    data_next_per_head_list.append(data_next_per_head)
+    mask_next_per_head = jnp.concatenate(
+        mask_next_sequence_slices, axis=q_sequence_axis
+    )
+    mask_next_per_head_list.append(mask_next_per_head)
+
+  # Concatenate (or broadcast) the head shards.
+  data_next = jnp.concatenate(data_next_per_head_list, axis=head_axis)
+  mask_next = jnp.concatenate(mask_next_per_head_list, axis=head_axis)

-  partial_mask_blocks = partial_mask_blocks.reshape(-1, *block_shape)
   if is_dkv:
     partial_mask_blocks = partial_mask_blocks.swapaxes(-1, -2)

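To make the per-shard bookkeeping above concrete, here is the same jnp.where logic applied to one small, hypothetical slice (heads_per_shard=1, q_blocks_per_shard=2, kv_blocks_count=3); block-mask values 0/1/2 mean empty/partial/full:

import math
import numpy as np
import jax.numpy as jnp

# (heads_per_shard, q_blocks_per_shard, kv_blocks_count), hypothetical.
mask_info_slice_shape = (1, 2, 3)
local_block_mask = jnp.array([[[2, 1, 0],
                               [1, 2, 1]]], dtype=jnp.int32)

# mask_next: flat index of the partial mask block; 0 for full/empty blocks.
mask_next_slice = jnp.arange(
    math.prod(mask_info_slice_shape), dtype=np.int32
).reshape(mask_info_slice_shape)
mask_next_slice = jnp.where(local_block_mask == 1, mask_next_slice, 0)
# -> [[[0, 1, 0], [3, 0, 5]]]

# data_next (non-dkv case): index of the kv data block to fetch; 0 for empty blocks.
data_next_slice = jnp.broadcast_to(
    jnp.arange(3, dtype=np.int32)[None, None, :], mask_info_slice_shape
)
data_next_slice = jnp.where(local_block_mask == 0, 0, data_next_slice)
# -> [[[0, 1, 0], [0, 1, 2]]]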
@@ -438,9 +491,11 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array:
   if downcast_smem_data:
     block_mask = block_mask.astype(np.int8)  # values are in the range [0, 1, 2]
     data_next = _downcast(
-        data_next, q_blocks_count if is_dkv else kv_blocks_count
+        data_next, q_blocks_per_shard if is_dkv else kv_blocks_count
+    )
+    mask_next = _downcast(
+        mask_next, heads_per_shard * q_blocks_per_shard * kv_blocks_count
     )
-    mask_next = _downcast(mask_next, math.prod(block_mask_shape))

   return (
       MaskInfo(
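The mask_next bound passed to _downcast above is now the per-shard slice size (heads_per_shard * q_blocks_per_shard * kv_blocks_count) rather than the global math.prod(block_mask_shape), since the indices built in the loop are local to a slice. A tighter bound can permit a narrower SMEM dtype; a hypothetical example of the two bounds (the body of `_downcast` is not part of this hunk, only its call sites):

import numpy as np

# Hypothetical geometry: 8 heads over 2 head shards, 8 q blocks over 2
# sequence shards, 4 kv blocks.
heads_per_shard, q_blocks_per_shard, kv_blocks_count = 4, 4, 4
head_count, q_blocks_count = 8, 8

global_bound = head_count * q_blocks_count * kv_blocks_count               # 256
per_shard_bound = heads_per_shard * q_blocks_per_shard * kv_blocks_count   # 64

assert global_bound > np.iinfo(np.int8).max      # would need at least int16
assert per_shard_bound <= np.iinfo(np.int8).max  # fits in int8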
@@ -449,6 +504,7 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array:
           block_mask=block_mask,
           partial_mask_blocks=partial_mask_blocks,
           q_sequence=None,
+          is_dynamic_mask=True,
       ),
       None,
   )

tests/pallas/BUILD

Lines changed: 15 additions & 0 deletions
@@ -540,6 +540,21 @@ jax_multiplatform_test(
     ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
 )

+jax_multiplatform_test(
+    name = "tpu_splash_attention_kernel_sharded_test",
+    srcs = ["tpu_splash_attention_kernel_sharded_test.py"],
+    enable_configs = [
+        "tpu_v5e_4x2",
+        "tpu_v5p_2x2",
+    ],
+    shard_count = 5,
+    deps = [
+        "//jax:extend",
+        "//jax:pallas_tpu",
+        "//jax:pallas_tpu_ops",
+    ],
+)
+
 # This test doesn't need a TPU; it only tests numpy-using helpers.
 jax_py_test(
     name = "tpu_splash_attention_mask_test",
