Commit bc130c7

Merge pull request jax-ml#25213 from Rifur13:dynamic_mask
PiperOrigin-RevId: 720361301
2 parents 36679d8 + c0d23af commit bc130c7

4 files changed: +286 additions, −18 deletions

jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py

Lines changed: 18 additions & 6 deletions
@@ -46,7 +46,7 @@
 class SegmentIds(NamedTuple):
   """SegmentIds for Q and KV sequences.
 
-  SegmentIds are a mechanims to ensure that there is no cross-attention between
+  SegmentIds are a mechanism to ensure that there is no cross-attention between
   segments (fraction of a sequence) that have been concatenated together into a
   sequence. Each array is a list of ids (integers). Only tokens with the same
   id are allowed to attend to each other.
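
As an aside, here is a minimal sketch (not part of this commit) of how segment ids are typically built for two packed segments; it assumes the `SegmentIds` NamedTuple above exposes `q` and `kv` fields holding one integer id per token.

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel as splash,
)

# Two segments of length 2 packed into one sequence of length 4.
# Tokens may only attend to tokens carrying the same id.
segment_ids = splash.SegmentIds(
    q=jnp.array([0, 0, 1, 1], dtype=jnp.int32),
    kv=jnp.array([0, 0, 1, 1], dtype=jnp.int32),
)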
@@ -2392,7 +2392,7 @@ def tree_unflatten(cls, kwargs, values):
 
 
 def _make_splash_attention(
-    mask: np.ndarray | mask_lib.MultiHeadMask,
+    mask: np.ndarray | jax.Array | mask_lib.MultiHeadMask,
     *,
     block_sizes: BlockSizes | None = None,
     is_mqa: bool,
@@ -2415,14 +2415,26 @@ def _make_splash_attention(
 
   if block_sizes is None:
     block_sizes = BlockSizes.get_default()
-  fwd_mask_info, mask_function_fwd = mask_info_lib.process_mask(
+
+  process_mask_fn = (
+      mask_info_lib.process_dynamic_mask
+      if isinstance(mask, jax.Array)
+      else mask_info_lib.process_mask
+  )
+
+  process_mask_dvk_fn = (
+      mask_info_lib.process_dynamic_mask_dkv
+      if isinstance(mask, jax.Array)
+      else mask_info_lib.process_mask_dkv
+  )
+
+  fwd_mask_info, mask_function_fwd = process_mask_fn(
       mask,
       (block_sizes.block_q, block_sizes.block_kv),
       downcast_smem_data=downcast_smem_data,
       head_shards=head_shards,
       q_seq_shards=q_seq_shards,
   )
-
   fwd_mask_info = tree_util.tree_map(jnp.array, fwd_mask_info)
 
   dq_mask_info = None
@@ -2432,7 +2444,7 @@ def _make_splash_attention(
     dq_mask_info = None
   else:
     bq_dq, bkv_dq = block_sizes.block_q_dq, block_sizes.block_kv_dq
-    dq_mask_info, mask_function_dq = mask_info_lib.process_mask(
+    dq_mask_info, mask_function_dq = process_mask_fn(
        mask,
        (bq_dq, bkv_dq),
        downcast_smem_data=downcast_smem_data,
@@ -2442,7 +2454,7 @@ def _make_splash_attention(
     assert (mask_function_fwd is None) == (mask_function_dq is None)
     dq_mask_info = tree_util.tree_map(jnp.array, dq_mask_info)
   bq_dkv, bkv_dkv = block_sizes.block_q_dkv, block_sizes.block_kv_dkv
-  dkv_mask_info, mask_function_dkv = mask_info_lib.process_mask_dkv(
+  dkv_mask_info, mask_function_dkv = process_mask_dvk_fn(
      mask,
      (bq_dkv, bkv_dkv),
      downcast_smem_data=downcast_smem_data,
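
With this dispatch in place, a caller can hand the kernel builder a dense boolean `jax.Array` mask instead of a numpy array or `MultiHeadMask`. The following is a hedged sketch, not part of this commit: it assumes the public `make_splash_mha` wrapper around `_make_splash_attention`, and the shapes, dtypes, and default block sizes are illustrative only (running it requires a TPU backend).

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel as splash,
)

num_heads, q_seq_len, kv_seq_len, head_dim = 2, 1024, 1024, 128

# A dense per-head causal mask held in a jax.Array. Because it is a jax.Array,
# _make_splash_attention routes it through process_dynamic_mask above.
mask = jnp.tril(jnp.ones((num_heads, q_seq_len, kv_seq_len), dtype=jnp.bool_))

kernel = splash.make_splash_mha(mask, head_shards=1, q_seq_shards=1)

q = jnp.zeros((num_heads, q_seq_len, head_dim), dtype=jnp.float32)
k = jnp.zeros((num_heads, kv_seq_len, head_dim), dtype=jnp.float32)
v = jnp.zeros((num_heads, kv_seq_len, head_dim), dtype=jnp.float32)
out = kernel(q, k, v)  # launches the Pallas splash attention kernel on TPU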

jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py

Lines changed: 160 additions & 9 deletions
@@ -18,9 +18,13 @@
 import collections
 from collections.abc import Callable
 import functools
+import math
 from typing import NamedTuple
+
+import jax
 from jax import util as jax_util
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
+import jax.numpy as jnp
 import numpy as np
 
 # mypy: ignore-errors
@@ -65,10 +69,10 @@ class MaskInfo(NamedTuple):
     causal this is just np.arange(q_sequence_length).
   """
 
-  data_next: np.ndarray | None
-  mask_next: np.ndarray | None
-  block_mask: np.ndarray | None
-  partial_mask_blocks: np.ndarray | None
+  data_next: np.ndarray | jax.Array | None
+  mask_next: np.ndarray | jax.Array | None
+  block_mask: np.ndarray | jax.Array | None
+  partial_mask_blocks: np.ndarray | jax.Array | None
   q_sequence: np.ndarray | None
 
 
@@ -245,7 +249,7 @@ def _get_mask_info_for_shard(
   mask_next = np.zeros(output_shape, dtype=np.int32)
   data_next = np.zeros(output_shape, dtype=np.int32)
 
-  # If the mask is completelly zero'd out return freshly initialized outputs.
+  # If the mask is completely zero'd out return freshly initialized outputs.
   if not data_coords:
     return data_next, mask_next
 
@@ -304,6 +308,152 @@ def _get_mask_info_for_shard(
   return data_next, mask_next
 
 
+def _process_dynamic_mask(
+    mask: jax.Array,
+    block_shape: tuple[int, int],
+    is_dkv: bool,
+    *,
+    downcast_smem_data: bool = True,
+    head_shards: int = 1,
+    q_seq_shards: int = 1,
+    shrink_grid: bool = True,
+) -> tuple[MaskInfo, None]:
+  """Similar to `_process_mask` but the mask must be a dynamic array.
+
+  Since the mask is dynamic, we can't know the exact number of partial mask
+  blocks at trace time. Therefore, the entire mask is materialized in
+  `partial_mask_blocks`.
+
+  Note that we can still populate MaskInfo to skip fully-masked blocks.
+
+  Args:
+    mask: A [head_count, q_seq_len, kv_seq_len] jax.Array representing the dense
+      mask to process.
+    block_shape: A Tuple[int, int] representing the shape of the Pallas grid
+      block.
+    is_dkv: True if we are processing the dKV mask.
+    downcast_smem_data: If True, downcast the scalar-memory data of MaskInfo to
+      a data type smaller than np.int32 (if possible).
+    head_shards: Number of head shards of the mesh in which the kernel is
+      launched.
+    q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is
+      launched.
+    shrink_grid: Whether or not we should apply the grid shrinking optimization. This is currently ignored.
+
+  Returns:
+    `MaskInfo`, a sparse representation of the dense mask.
+
+  Raises:
+    ValueError: if the input mask is invalid or the block sizes are not
+      compatible with the mask sizes.
+  """
+
+  del shrink_grid
+
+  # TODO(pobudzey): Properly support sharding.
+  if head_shards != 1 or q_seq_shards != 1:
+    raise ValueError('Dynamic mask processing does not support sharding.')
+
+  if len(mask.shape) != 3:
+    raise ValueError(f'Expected a 3-dim mask, instead got: {mask.shape}.')
+
+  if mask.dtype != jnp.bool:
+    raise ValueError(f'Expected a bool mask, instead got: {mask.dtype}.')
+
+  head_count, q_seq_len, kv_seq_len = mask.shape
+  q_block_size, kv_block_size = block_shape
+  q_blocks_count, q_mod = divmod(q_seq_len, q_block_size)
+  kv_blocks_count, kv_mod = divmod(kv_seq_len, kv_block_size)
+
+  if q_mod != 0:
+    raise ValueError(f'{q_block_size=} should divide {q_seq_len=}.')
+  if kv_mod != 0:
+    raise ValueError(f'{kv_block_size=} should divide {kv_seq_len=}.')
+
+  block_mask_shape = (
+      head_count,
+      q_blocks_count,
+      kv_blocks_count,
+  )
+
+  # Tile the last 2 dimensions of the mask into 2D tiles of size `block_shape`.
+  partial_mask_blocks = (
+      mask.reshape(
+          head_count,
+          q_blocks_count,
+          q_block_size,
+          kv_blocks_count,
+          kv_block_size,
+      )
+      .swapaxes(-2, -3)
+      .astype(np.bool_)
+  )
+
+  # The block mask is 2 for all blocks with all entries set to True and 1 for
+  # blocks with a mix of True and False entries.
+  is_full_mask = jnp.all(partial_mask_blocks, axis=(-1, -2))
+  is_empty_mask = jnp.logical_not(jnp.any(partial_mask_blocks, axis=(-1, -2)))
+
+  block_mask = jnp.ones(block_mask_shape, dtype=np.int32)
+  block_mask = jnp.where(is_full_mask, 2, block_mask)
+  block_mask = jnp.where(is_empty_mask, 0, block_mask)
+
+  # TODO(pobudzey): Return the next valid mask index instead of 0 for a more efficient pipeline.
+  mask_next = jnp.where(
+      jnp.logical_or(is_empty_mask, is_full_mask),
+      0,
+      jnp.arange(math.prod(block_mask_shape), dtype=np.int32).reshape(
+          block_mask_shape
+      ),
+  )
+
+  # data_next stores the index of the next non-empty data block in the sequence.
+  # The indices of empty blocks are set to 0 to avoid copying extra data when
+  # pipelining.
+  if is_dkv:
+    data_next = jnp.arange(q_blocks_count, dtype=np.int32)[None, :, None]
+  else:
+    data_next = jnp.arange(kv_blocks_count, dtype=np.int32)[None, None, :]
+  data_next = jnp.broadcast_to(data_next, block_mask_shape)
+  data_next = jnp.where(is_empty_mask, 0, data_next)
+
+  partial_mask_blocks = partial_mask_blocks.reshape(-1, *block_shape)
+  if is_dkv:
+    partial_mask_blocks = partial_mask_blocks.swapaxes(-1, -2)
+
+  def _downcast(array: jax.Array, max_value: int) -> jax.Array:
+    if array.size == 0:
+      return array
+
+    if array.dtype != np.int32:
+      raise ValueError(f'Expected int32 input, but got {array.dtype}.')
+
+    if max_value <= np.iinfo(np.int8).max:
+      return array.astype(np.int8)
+    elif max_value <= np.iinfo(np.int16).max:
+      return array.astype(np.int16)
+    else:
+      return array.astype(np.int32)
+
+  if downcast_smem_data:
+    block_mask = block_mask.astype(np.int8)  # values are in the range [0, 1, 2]
+    data_next = _downcast(
+        data_next, q_blocks_count if is_dkv else kv_blocks_count
+    )
+    mask_next = _downcast(mask_next, math.prod(block_mask_shape))
+
+  return (
+      MaskInfo(
+          data_next=data_next,
+          mask_next=mask_next,
+          block_mask=block_mask,
+          partial_mask_blocks=partial_mask_blocks,
+          q_sequence=None,
+      ),
+      None,
+  )
+
+
 # When used in a transformer network with multiple layers, the SplashAttention
 # kernel is created several times with the same mask. Cache MaskInfo to avoid
 # blowing up compile times. Ideally the size of the cache should be determined
@@ -410,7 +560,7 @@ def assign_unique_ids(objects):
     mask_id_to_heads[mask_id].append(head)
     mask_id_to_head_shards[mask_id].add(head_shard)
 
-  # If we have at most one unique mask per each head shard, then we can brodcast
+  # If we have at most one unique mask per each head shard, then we can broadcast
   # the mask to all the heads in the shard. This is the common case.
   # If we have more than one mask in each head shard, then the optimization
   # cannot kick in and we use one mask for each head.
@@ -699,9 +849,7 @@ def set_block_mask(mask_id: int, q_index: int, kv_index: int, value: int):
       current_block_mask,
       current_data_next,
       current_mask_next,
-  ) in zip(
-      block_mask_shards, data_next_shards, mask_next_shards
-  ):
+  ) in zip(block_mask_shards, data_next_shards, mask_next_shards):
     # For dKV shrinking happens along axis Q (the rows of MaskInfo), for
     # fwd and dQ shrinking happens along axis KV (the columns of MaskInfo).
     if is_dkv:
@@ -924,3 +1072,6 @@ def _slice_mask_info(
 
 process_mask = functools.partial(_process_mask, is_dkv=False)
 process_mask_dkv = functools.partial(_process_mask, is_dkv=True)
+
+process_dynamic_mask = functools.partial(_process_dynamic_mask, is_dkv=False)
+process_dynamic_mask_dkv = functools.partial(_process_dynamic_mask, is_dkv=True)
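
The 0/1/2 block-mask encoding computed for the dynamic path can be checked on a tiny mask. The sketch below is illustrative and not part of this commit; it only exercises the pure jax.numpy mask processing added above (no Pallas kernel is launched), and the expected values follow from the full/partial/empty block logic.

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_mask_info as mask_info_lib,
)

# One head, a 4x4 causal mask, tiled into 2x2 blocks -> a 2x2 block grid.
mask = jnp.tril(jnp.ones((1, 4, 4), dtype=jnp.bool_))
mask_info, _ = mask_info_lib.process_dynamic_mask(mask, (2, 2))

# block_mask[0] is expected to be:
#   [[1, 0],    # top-left block is partial, top-right block is fully masked
#    [2, 1]]    # bottom-left block is fully visible, bottom-right is partial
# and data_next zeroes out the entry of the empty block.
print(mask_info.block_mask[0])
print(mask_info.data_next[0])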

tests/pallas/tpu_splash_attention_kernel_test.py

Lines changed: 26 additions & 3 deletions
@@ -292,6 +292,14 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]:
   return hps.one_of(hps.just(None), hps.floats(min_value=1.0, max_value=50.0))
 
 
+def to_dynamic_mask(mask: mask_lib.MultiHeadMask) -> jax.Array:
+  q_seq_len, kv_seq_len = mask.masks[0].shape
+  full_mask_slice = (slice(0, q_seq_len), slice(0, kv_seq_len))
+  dynamic_mask = jnp.stack([m[full_mask_slice] for m in mask.masks], axis=0)
+
+  return dynamic_mask
+
+
 @jtu.with_config(jax_traceback_filtering="off")
 class PallasBaseTest(jtu.JaxTestCase):
   INTERPRET = False
@@ -322,9 +330,10 @@ class SplashAttentionTest(PallasBaseTest):
   @parameterized.product(
       is_mqa=(False, True),
       is_segmented=(False, True),
+      is_dynamic_mask=(False, True),
   )
   @hp.given(hps.data())
-  def test_splash_attention(self, is_mqa, is_segmented, data):
+  def test_splash_attention(self, is_mqa, is_segmented, is_dynamic_mask, data):
     seed = data.draw(seed_strategy())
     key = random.key(seed)
     k1, k2, k3 = random.split(key, 3)
@@ -353,6 +362,8 @@ def test_splash_attention(self, is_mqa, is_segmented, data):
     attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy())
     masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads))
     mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks))
+    if is_dynamic_mask:
+      mask = to_dynamic_mask(mask)
     block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len))
 
     if is_mqa:
@@ -384,10 +395,11 @@ def test_splash_attention(self, is_mqa, is_segmented, data):
   @parameterized.product(
       is_mqa=(False, True),
       is_segmented=(False, True),
+      is_dynamic_mask=(False, True),
   )
   @hp.given(hps.data())
   def test_splash_attention_fwd(
-      self, is_mqa, is_segmented, data
+      self, is_mqa, is_segmented, is_dynamic_mask, data
   ):
     seed = data.draw(seed_strategy())
     key = random.key(seed)
@@ -416,6 +428,8 @@ def test_splash_attention_fwd(
     attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy())
     masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads))
     mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks))
+    if is_dynamic_mask:
+      mask = to_dynamic_mask(mask)
     block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len))
     if is_mqa:
       attn_ref = splash.make_masked_mqa_reference(mask)
@@ -531,10 +545,17 @@ def test_splash_attention_custom_bwd(self, is_segmented, data):
       is_segmented=(False, True),
       downcast_smem_data=(False, True),
       use_fused_bwd_kernel=(False, True),
+      use_dynamic_mask=(False, True),
   )
   @hp.given(hps.data())
   def test_splash_attention_bwd(
-      self, is_mqa, is_segmented, downcast_smem_data, use_fused_bwd_kernel, data
+      self,
+      is_mqa,
+      is_segmented,
+      downcast_smem_data,
+      use_fused_bwd_kernel,
+      use_dynamic_mask,
+      data,
   ):
     seed = data.draw(seed_strategy())
     key = random.key(seed)
@@ -563,6 +584,8 @@ def test_splash_attention_bwd(
     attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy())
     masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads))
    mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks))
+    if use_dynamic_mask:
+      mask = to_dynamic_mask(mask)
     block_sizes = data.draw(
         block_sizes_strategy(q_seq_len, kv_seq_len, include_bwd_blocks=True,
                              use_fused_bwd_kernel=use_fused_bwd_kernel)
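
For completeness, here is a standalone sketch of the densification that the `to_dynamic_mask` helper above performs, applied to a structured causal mask. It is illustrative only and not part of this commit; the `CausalMask` constructor and the tuple-of-slices indexing are assumed from `splash_attention_mask`.

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_mask as mask_lib,
)

num_heads, q_seq_len, kv_seq_len = 2, 256, 256
causal = mask_lib.CausalMask(shape=(q_seq_len, kv_seq_len))
multi_head = mask_lib.MultiHeadMask((causal,) * num_heads)

# Slice every per-head mask into a dense boolean array and stack along a
# leading head axis, matching the layout expected by the dynamic-mask path:
# [num_heads, q_seq_len, kv_seq_len].
full_slice = (slice(0, q_seq_len), slice(0, kv_seq_len))
dynamic_mask = jnp.stack([m[full_slice] for m in multi_head.masks], axis=0)
assert dynamic_mask.shape == (num_heads, q_seq_len, kv_seq_len)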
