Skip to content

Commit e92ca9b

Browse files
Rifur13 authored and Google-ML-Automation committed
Use boolean values for partial mask blocks in the splash attention kernel.
The values are guaranteed to be 0 or 1 since we create this array ourselves when processing the masks into a MaskInfo object.
PiperOrigin-RevId: 705252534
1 parent b7af1eb commit e92ca9b

File tree

3 files changed

+51
-53
lines changed

3 files changed

+51
-53
lines changed

jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -603,9 +603,9 @@ def _apply_mask_and_soft_cap(
603603
else:
604604
mask = pl.load(mask_ref, (k_slice, slice(None)))
605605

606-
snm = jnp.where(should_not_mask, 1, 0)
607-
masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
608-
606+
masks.append(
607+
jnp.bitwise_or(mask, jnp.broadcast_to(should_not_mask, mask.shape))
608+
)
609609
if mask_function is not None:
610610
# Compute the mask using the given q_sequence indices.
611611
# KV indices are computed on the fly. This works because we only support Q
@@ -900,6 +900,16 @@ def _splash_attention_forward(
900900
kv_seq_len_dimension = 1
901901
num_kv_heads = k.shape[0]
902902

903+
partial_mask_blocks = fwd_mask_info.partial_mask_blocks
904+
if (
905+
partial_mask_blocks is not None
906+
and jnp.dtype(partial_mask_blocks.dtype) != np.bool_
907+
):
908+
raise ValueError(
909+
"partial_mask_blocks must be of type np.bool_ but got"
910+
f" {partial_mask_blocks.dtype}"
911+
)
912+
903913
if len(k.shape) != expected_kv_rank:
904914
raise ValueError(
905915
f"Expected {expected_kv_rank}-dim 'key' tensor for MQA. Instead got a"

jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py

Lines changed: 4 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ class MaskInfo(NamedTuple):
5656
indicates that the corresponding block in the full mask contained both
5757
zeros and ones. An entry of 2 indicates the corresponding block was
5858
entirely ones.
59-
partial_mask_blocks: A i32[num_partial_blocks, block_q, block_kv] NumPy
59+
partial_mask_blocks: A bool[num_partial_blocks, block_q, block_kv] NumPy
6060
array that contains the blocks of the original mask that contained both
6161
zeros and ones. The entries in `mask_next` point to indices in the first
6262
axis of this array.
@@ -305,7 +305,7 @@ def _get_mask_info_for_shard(
305305

306306

307307
# When used in a transformer network with multiple layers, the SplashAttention
308-
# kernel is created serveral times with the same mask. Cache MaskInfo to avoid
308+
# kernel is created several times with the same mask. Cache MaskInfo to avoid
309309
# blowing up compile times. Ideally the size of the cache should be determined
310310
# by the client.
311311
@functools.lru_cache(maxsize=12)
@@ -376,14 +376,6 @@ def _process_mask(
376376
if mod != 0:
377377
raise ValueError(f'{head_shards=} should divide {head_count=}.')
378378

379-
first_mask_size = mask.masks[0].shape
380-
for h in range(head_count):
381-
if mask.masks[h].shape != first_mask_size:
382-
raise ValueError(
383-
f'First head mask has shape {first_mask_size}, but head mask {h} has'
384-
f' shape {mask.masks[h].shape}. All head masks must have the same'
385-
' shape.'
386-
)
387379

388380
# Uniquify the masks.
389381
# Create a collection of the unique head masks in the input multi-head mask.
@@ -526,13 +518,9 @@ def set_block_mask(mask_id: int, q_index: int, kv_index: int, value: int):
526518

527519
partial_mask_blocks = None
528520
has_mask_next = False
529-
if len(unique_partial_mask_blocks) == 1:
530-
partial_mask_blocks = [x.array for x in unique_partial_mask_blocks]
531-
partial_mask_blocks = partial_mask_blocks[0][None].astype(np.int32)
532-
has_mask_next = True
533-
elif len(unique_partial_mask_blocks) > 1:
521+
if len(unique_partial_mask_blocks) >= 1:
534522
partial_mask_blocks = [x.array for x in unique_partial_mask_blocks]
535-
partial_mask_blocks = np.stack(partial_mask_blocks, axis=0).astype(np.int32)
523+
partial_mask_blocks = np.stack(partial_mask_blocks, axis=0).astype(np.bool_)
536524
has_mask_next = True
537525
if is_dkv and partial_mask_blocks is not None:
538526
partial_mask_blocks = np.swapaxes(partial_mask_blocks, -1, -2)

tests/pallas/tpu_splash_attention_mask_test.py

Lines changed: 34 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,9 @@
1818
from absl.testing import absltest
1919
from absl.testing import parameterized
2020
import jax
21+
from jax._src import test_util as jtu
2122
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
2223
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask_info as mask_info_lib
23-
from jax._src import test_util as jtu
2424
import numpy as np
2525

2626
jax.config.parse_flags_with_absl()
@@ -798,7 +798,7 @@ def test_two_causal_masks(self, is_lazy_mask: bool):
798798
self._expected_causal_data_next[None],
799799
self._expected_causal_mask_next(0)[None] if not is_lazy_mask else None,
800800
self._expected_causal_block_mask[None],
801-
np.expand_dims(np.tril(np.ones(block_shape, dtype=np.int32)), 0)
801+
np.expand_dims(np.tril(np.ones(block_shape, dtype=np.bool_)), 0)
802802
if not is_lazy_mask
803803
else None,
804804
np.arange(sequence_lengths[0], dtype=np.int32)
@@ -813,7 +813,7 @@ def test_two_causal_masks(self, is_lazy_mask: bool):
813813
else None,
814814
self._expected_causal_block_mask_dkv[None],
815815
np.expand_dims(
816-
np.tril(np.ones(block_shape, dtype=np.int32)), 0
816+
np.tril(np.ones(block_shape, dtype=np.bool_)), 0
817817
).swapaxes(-1, -2)
818818
if not is_lazy_mask
819819
else None,
@@ -851,7 +851,7 @@ def test_rectangular_wide_causal_mask(self, is_lazy_mask: bool):
851851
self._expected_causal_data_next[None],
852852
self._expected_causal_mask_next(0)[None] if not is_lazy_mask else None,
853853
self._expected_causal_block_mask[None],
854-
np.expand_dims(np.tril(np.ones(block_shape, dtype=np.int32)), 0)
854+
np.expand_dims(np.tril(np.ones(block_shape, dtype=np.bool_)), 0)
855855
if not is_lazy_mask
856856
else None,
857857
np.arange(sequence_lengths[0], dtype=np.int32)
@@ -894,7 +894,7 @@ def test_rectangular_wide_causal_mask(self, is_lazy_mask: bool):
894894
expected_causal_mask_next_dkv if not is_lazy_mask else None,
895895
expected_causal_block_mask_dkv,
896896
np.expand_dims(
897-
np.tril(np.ones(block_shape, dtype=np.int32)), 0
897+
np.tril(np.ones(block_shape, dtype=np.bool_)), 0
898898
).swapaxes(-1, -2)
899899
if not is_lazy_mask
900900
else None,
@@ -974,7 +974,7 @@ def test_rectangular_tall_causal_mask(self, is_lazy_mask: bool):
974974
expected_causal_data_next,
975975
expected_causal_mask_next if not is_lazy_mask else None,
976976
expected_causal_block_mask,
977-
np.expand_dims(np.tril(np.ones(block_shape, dtype=np.int32)), 0)
977+
np.expand_dims(np.tril(np.ones(block_shape, dtype=np.bool_)), 0)
978978
if not is_lazy_mask
979979
else None,
980980
np.arange(sequence_lengths[0], dtype=np.int32)
@@ -1029,7 +1029,7 @@ def test_rectangular_tall_causal_mask(self, is_lazy_mask: bool):
10291029
expected_causal_mask_next_dkv if not is_lazy_mask else None,
10301030
expected_causal_block_mask_dkv,
10311031
np.expand_dims(
1032-
np.tril(np.ones(block_shape, dtype=np.int32)), 0
1032+
np.tril(np.ones(block_shape, dtype=np.bool_)), 0
10331033
).swapaxes(-1, -2)
10341034
if not is_lazy_mask
10351035
else None,
@@ -1069,10 +1069,10 @@ def test_local_mask(self, is_lazy_mask: bool):
10691069
expected_partial_mask_blocks = self._stack(
10701070
[
10711071
np.triu(
1072-
np.tri(*block_shape, window_size, dtype=np.int32), -window_size
1072+
np.tri(*block_shape, window_size, dtype=np.bool_), -window_size
10731073
),
1074-
np.tri(*block_shape, -window_size, dtype=np.int32),
1075-
np.triu(np.ones(block_shape, dtype=np.int32), window_size),
1074+
np.tri(*block_shape, -window_size, dtype=np.bool_),
1075+
np.triu(np.ones(block_shape, dtype=np.bool_), window_size),
10761076
],
10771077
)
10781078

@@ -1179,8 +1179,8 @@ def test_local_mask_narrow(self, is_lazy_mask: bool):
11791179

11801180
expected_partial_mask_blocks = self._stack(
11811181
[
1182-
np.triu(np.tri(*block_shape, 0, dtype=np.int32), -window_size),
1183-
np.triu(np.ones(block_shape, dtype=np.int32), window_size),
1182+
np.triu(np.tri(*block_shape, 0, dtype=np.bool_), -window_size),
1183+
np.triu(np.ones(block_shape, dtype=np.bool_), window_size),
11841184
],
11851185
)
11861186

@@ -1298,13 +1298,13 @@ def test_two_head_shards_one_causal_one_local(self, is_lazy_mask: bool):
12981298
)
12991299

13001300
expected_partial_mask_blocks = self._stack([
1301-
np.tril(np.ones(block_shape, dtype=np.int32)),
1301+
np.tril(np.ones(block_shape, dtype=np.bool_)),
13021302
np.triu(
1303-
np.tri(*block_shape, window_size, dtype=np.int32),
1303+
np.tri(*block_shape, window_size, dtype=np.bool_),
13041304
-window_size,
13051305
),
1306-
np.tri(*block_shape, -window_size, dtype=np.int32),
1307-
np.triu(np.ones(block_shape, dtype=np.int32), window_size),
1306+
np.tri(*block_shape, -window_size, dtype=np.bool_),
1307+
np.triu(np.ones(block_shape, dtype=np.bool_), window_size),
13081308
])
13091309

13101310
expected_block_mask_dkv = self._stack(
@@ -1384,7 +1384,7 @@ def test_two_head_shards_causal_full(self, is_lazy_mask: bool):
13841384
])
13851385

13861386
expected_partial_mask_blocks = np.expand_dims(
1387-
np.tril(np.ones(block_shape, dtype=np.int32)), 0
1387+
np.tril(np.ones(block_shape, dtype=np.bool_)), 0
13881388
)
13891389

13901390
expected_mask_info = mask_info_lib.MaskInfo(
@@ -1460,13 +1460,13 @@ def test_two_qseq_shards_causal_local(self, is_lazy_mask: bool):
14601460
)
14611461

14621462
expected_partial_mask_blocks = self._stack([
1463-
np.tril(np.ones(block_shape, dtype=np.int32)),
1463+
np.tril(np.ones(block_shape, dtype=np.bool_)),
14641464
np.triu(
1465-
np.tri(*block_shape, window_size, dtype=np.int32),
1465+
np.tri(*block_shape, window_size, dtype=np.bool_),
14661466
-window_size,
14671467
),
1468-
np.tri(*block_shape, -window_size, dtype=np.int32),
1469-
np.triu(np.ones(block_shape, dtype=np.int32), window_size),
1468+
np.tri(*block_shape, -window_size, dtype=np.bool_),
1469+
np.triu(np.ones(block_shape, dtype=np.bool_), window_size),
14701470
])
14711471

14721472
expected_mask_info = mask_info_lib.MaskInfo(
@@ -1577,13 +1577,13 @@ def test_two_qseq_shards_causal_local_stacked(self):
15771577
)
15781578

15791579
expected_partial_mask_blocks = self._stack([
1580-
np.tril(np.ones(block_shape, dtype=np.int32)),
1580+
np.tril(np.ones(block_shape, dtype=np.bool_)),
15811581
np.triu(
1582-
np.tri(*block_shape, window_size, dtype=np.int32),
1582+
np.tri(*block_shape, window_size, dtype=np.bool_),
15831583
-window_size,
15841584
),
1585-
np.tri(*block_shape, -window_size, dtype=np.int32),
1586-
np.triu(np.ones(block_shape, dtype=np.int32), window_size),
1585+
np.tri(*block_shape, -window_size, dtype=np.bool_),
1586+
np.triu(np.ones(block_shape, dtype=np.bool_), window_size),
15871587
])
15881588

15891589
expected_mask_info = mask_info_lib.MaskInfo(
@@ -1749,13 +1749,13 @@ def test_two_qseq_shards_local_wide_local_narrow_stacked(self):
17491749
expected_partial_mask_blocks = self._stack([
17501750
# Wide
17511751
np.triu(
1752-
np.tri(*block_shape, window_size, dtype=np.int32),
1752+
np.tri(*block_shape, window_size, dtype=np.bool_),
17531753
-window_size,
17541754
),
1755-
np.tri(*block_shape, -window_size, dtype=np.int32),
1756-
np.triu(np.ones(block_shape, dtype=np.int32), window_size),
1755+
np.tri(*block_shape, -window_size, dtype=np.bool_),
1756+
np.triu(np.ones(block_shape, dtype=np.bool_), window_size),
17571757
# Narrow
1758-
np.triu(np.tri(*block_shape, 0, dtype=np.int32), -window_size),
1758+
np.triu(np.tri(*block_shape, 0, dtype=np.bool_), -window_size),
17591759
])
17601760

17611761
expected_mask_info = mask_info_lib.MaskInfo(
@@ -1890,7 +1890,7 @@ def test_two_head_shards_causal_mask(self, is_lazy_mask: bool):
18901890
)
18911891

18921892
expected_partial_mask_blocks = np.expand_dims(
1893-
np.tril(np.ones(block_shape, dtype=np.int32)), 0
1893+
np.tril(np.ones(block_shape, dtype=np.bool_)), 0
18941894
)
18951895

18961896
expected_mask_info = mask_info_lib.MaskInfo(
@@ -1979,13 +1979,13 @@ def test_two_head_shards_two_causal_two_local(self, is_lazy_mask: bool):
19791979

19801980
expected_partial_mask_blocks = self._stack(
19811981
[
1982-
np.tril(np.ones(block_shape, dtype=np.int32)),
1982+
np.tril(np.ones(block_shape, dtype=np.bool_)),
19831983
np.triu(
1984-
np.tri(*block_shape, window_size, dtype=np.int32),
1984+
np.tri(*block_shape, window_size, dtype=np.bool_),
19851985
-window_size,
19861986
),
1987-
np.tri(*block_shape, -window_size, dtype=np.int32),
1988-
np.triu(np.ones(block_shape, dtype=np.int32), window_size),
1987+
np.tri(*block_shape, -window_size, dtype=np.bool_),
1988+
np.triu(np.ones(block_shape, dtype=np.bool_), window_size),
19891989
],
19901990
)
19911991

0 commit comments

Comments
 (0)