
Commit 64d6d9b

Merge pull request #2557 from AI-Hypercomputer:mohit/fix_tokamax_2
PiperOrigin-RevId: 825333068
2 parents df45a16 + 287bc39

2 files changed: 7 additions, 8 deletions

src/MaxText/layers/attention_op.py

Lines changed: 7 additions & 6 deletions
@@ -144,8 +144,8 @@ def validate_flash_attention_with_sinks_on_gpu(sinks: Array | None) -> None:
     raise ValueError("The flash attention with sinks is not supported on GPU yet.")


-# TODO(agagik): change tokamax_splash_mask._ComputableMask to be non protected
-class ChunkedCausalMask(tokamax_splash_mask._ComputableMask):  # pylint: disable=protected-access
+# TODO(agagik): change splash_attention_mask._ComputableMask to be non protected
+class ChunkedCausalMask(splash_attention_mask._ComputableMask):  # pylint: disable=protected-access
   """Lazy chunked causal mask.

   Attention is causal within each chunk (0, K), (K, 2K), (2K, 3K), ... tokens attend to each other but not across chunks.
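
For reference, the chunk-local causal pattern described in the docstring above can be illustrated with a small dense sketch. This is illustrative only, not the lazy _ComputableMask implementation in the file; the helper name and its chunk_size argument are hypothetical.

# Illustrative only: a dense rendering of the chunk-local causal pattern the
# ChunkedCausalMask docstring describes. The real class is a lazy
# _ComputableMask; this helper and its `chunk_size` argument are hypothetical.
import numpy as np

def dense_chunked_causal_mask(q_seq_len, kv_seq_len, chunk_size):
  """Returns a (q_seq_len, kv_seq_len) boolean mask where position i attends
  to position j only if j <= i and both fall in the same chunk of size chunk_size."""
  q_idx = np.arange(q_seq_len)[:, None]
  kv_idx = np.arange(kv_seq_len)[None, :]
  causal = kv_idx <= q_idx
  same_chunk = (q_idx // chunk_size) == (kv_idx // chunk_size)
  return causal & same_chunk

# With chunk_size=2, token 2 starts a new chunk: it sees itself but not tokens 0 or 1.
print(dense_chunked_causal_mask(4, 4, 2).astype(int))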
@@ -1138,10 +1138,11 @@ def create_sa_config(config, query, key, attn_logits_soft_cap):

     sa_config = create_sa_config(self.config, query, key, attn_logits_soft_cap)
     mask_shape = (query.shape[2], key.shape[2])  # (q_seq_len, kv_seq_len)
+    mask_module = tokamax_splash_mask if self.config.use_tokamax_splash else splash_attention_mask
     if self.attention_type == AttentionType.FULL:
-      mask = splash_attention_mask.FullMask(mask_shape)
+      mask = mask_module.FullMask(mask_shape)
     else:
-      mask = splash_attention_mask.CausalMask(shape=mask_shape)
+      mask = mask_module.CausalMask(shape=mask_shape)

     # Create LoadBalancedCausalMask if cp and load_balancing
     if cp_size > 1 and load_balanced_context_parallel:
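
The added mask_module line picks between the Tokamax and JAX splash-attention mask libraries once, so the FullMask/CausalMask construction below it is written only one way. A minimal, self-contained sketch of that selection pattern follows; the two SimpleNamespace stand-ins are assumptions used only so the snippet runs without either library.

# Sketch of the module-selection pattern added above. The stand-in namespaces
# below are hypothetical; in attention_op.py the real `tokamax_splash_mask`
# and `splash_attention_mask` modules expose mask classes with matching names.
from types import SimpleNamespace

tokamax_splash_mask = SimpleNamespace(FullMask=lambda shape: ("tokamax FullMask", shape))
splash_attention_mask = SimpleNamespace(FullMask=lambda shape: ("jax FullMask", shape))

def build_full_mask(use_tokamax_splash, mask_shape):
  # Choose the module once; every mask built afterwards comes from it.
  mask_module = tokamax_splash_mask if use_tokamax_splash else splash_attention_mask
  return mask_module.FullMask(mask_shape)

print(build_full_mask(True, (8, 8)))   # ('tokamax FullMask', (8, 8))
print(build_full_mask(False, (8, 8)))  # ('jax FullMask', (8, 8))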
@@ -1152,7 +1153,7 @@ def create_sa_config(config, query, key, attn_logits_soft_cap):
     if self.attention_type == AttentionType.LOCAL_SLIDING:
       if self.sliding_window_size is None:
         raise ValueError("Sliding_window_size must be set if Local Sliding attention type")
-      mask &= splash_attention_mask.LocalMask(
+      mask &= mask_module.LocalMask(
           shape=(query.shape[2], key.shape[2]),
           window_size=(self.sliding_window_size, self.sliding_window_size),
           offset=0,
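
The `&=` in this hunk intersects the base causal mask with a LocalMask so that local sliding attention only sees a window of nearby past tokens. A dense analogue of that intersection is sketched below; the window rule used here (|i - j| <= window) is a simplifying assumption, not the kernel's exact LocalMask definition.

# Dense analogue of `mask &= mask_module.LocalMask(...)` above: intersect a
# causal mask with a sliding-window constraint. The window rule used here
# (|i - j| <= window) is a simplifying assumption for illustration.
import numpy as np

def dense_sliding_window_causal(seq_len, window):
  q_idx = np.arange(seq_len)[:, None]
  kv_idx = np.arange(seq_len)[None, :]
  causal = kv_idx <= q_idx
  local = np.abs(q_idx - kv_idx) <= window
  return causal & local  # token i attends to tokens in [i - window, i]

print(dense_sliding_window_causal(6, window=2).astype(int))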
@@ -1775,7 +1776,7 @@ def __call__(


 # pylint: disable=protected-access
-class LoadBalancedCausalMask(tokamax_splash_mask._ComputableMask):
+class LoadBalancedCausalMask(splash_attention_mask._ComputableMask):
   """Lazy causal mask, prevents the model from attending to future tokens.

   Attributes:
     offset: Offset of q start wrt kv. A positive offset shifts the bottom

tests/attention_test.py

Lines changed: 0 additions & 2 deletions
@@ -576,7 +576,6 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
       },
   )
   # TODO (b/454764135.) : This tests fails with new tokamax kernel
-  @pytest.mark.skip(reason="Issue w/ tokamax kernel CP->EP sharding correctness. ")
   @pytest.mark.tpu_only
   def test_tpu_flash_attention_context_parallel(
       self, ici_context_parallelism, context_parallel_load_balance, ici_expert_parallelism, expert_shard_attention_option
@@ -1289,7 +1288,6 @@ def test_projection_initialization(self):
       },
   )
   # TODO (b/454764135.) : This tests fails with new tokamax kernel
-  @pytest.mark.skip(reason="Issue w/ tokamax kernel CP->EP sharding correctness. ")
   @pytest.mark.tpu_only
   def test_tpu_flash_attention_context_parallel(
       self, ici_context_parallelism, context_parallel_load_balance, ici_expert_parallelism, expert_shard_attention_option
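
With the two @pytest.mark.skip decorators removed, test_tpu_flash_attention_context_parallel runs again wherever the tpu_only marker is selected. As a usage note (this invocation is an assumption, not part of the commit), the re-enabled tests could be targeted on a TPU host roughly like this:

# Hypothetical invocation, not from the commit: select only the re-enabled
# context-parallel tests by name and keep the repo's tpu_only marker filter.
import pytest

if __name__ == "__main__":
  raise SystemExit(pytest.main([
      "tests/attention_test.py",
      "-k", "test_tpu_flash_attention_context_parallel",
      "-m", "tpu_only",
  ]))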
