@@ -144,8 +144,8 @@ def validate_flash_attention_with_sinks_on_gpu(sinks: Array | None) -> None:
     raise ValueError("The flash attention with sinks is not supported on GPU yet.")
 
 
-# TODO(agagik): change tokamax_splash_mask._ComputableMask to be non protected
-class ChunkedCausalMask(tokamax_splash_mask._ComputableMask):  # pylint: disable=protected-access
+# TODO(agagik): change splash_attention_mask._ComputableMask to be non protected
+class ChunkedCausalMask(splash_attention_mask._ComputableMask):  # pylint: disable=protected-access
   """Lazy chunked causal mask.
 
   Attention is causal within each chunk (0, K), (K, 2K), (2K, 3K), ...: tokens within a chunk attend to each other, but not across chunks.
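For intuition, the within-chunk causal rule this docstring describes can be written as a standalone predicate. This is a minimal NumPy sketch, not the `_ComputableMask` machinery itself; the function name and `chunk_size` argument are illustrative:

```python
import numpy as np

def chunked_causal_mask_function(q_ids, kv_ids, chunk_size):
  """Illustrative predicate: causal attention restricted to within-chunk pairs.

  A query at position q may attend to a key at position kv only when both
  fall in the same chunk (q // chunk_size == kv // chunk_size) and the
  usual causal constraint q >= kv holds.
  """
  same_chunk = (q_ids // chunk_size) == (kv_ids // chunk_size)
  causal = q_ids >= kv_ids
  return same_chunk & causal

# Example: 6 tokens, chunk size 3. Token 3 starts a new chunk, so it does
# not attend to tokens 0-2 even though they precede it.
q = np.arange(6)[:, None]
kv = np.arange(6)[None, :]
print(chunked_causal_mask_function(q, kv, chunk_size=3).astype(int))
```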
@@ -1138,10 +1138,11 @@ def create_sa_config(config, query, key, attn_logits_soft_cap):
 
     sa_config = create_sa_config(self.config, query, key, attn_logits_soft_cap)
     mask_shape = (query.shape[2], key.shape[2])  # (q_seq_len, kv_seq_len)
+    mask_module = tokamax_splash_mask if self.config.use_tokamax_splash else splash_attention_mask
     if self.attention_type == AttentionType.FULL:
-      mask = splash_attention_mask.FullMask(mask_shape)
+      mask = mask_module.FullMask(mask_shape)
     else:
-      mask = splash_attention_mask.CausalMask(shape=mask_shape)
+      mask = mask_module.CausalMask(shape=mask_shape)
 
     # Create LoadBalancedCausalMask if cp and load_balancing
     if cp_size > 1 and load_balanced_context_parallel:
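The new `mask_module` line makes the rest of the mask construction backend-agnostic: both mask libraries expose the same class names, so one conditional picks the module and every later call goes through it. A toy sketch of the pattern, using hypothetical stand-in namespaces rather than the real tokamax/splash modules:

```python
import types

# Hypothetical stand-ins for splash_attention_mask and tokamax_splash_mask;
# both expose the same constructor names, so callers can hold either one.
legacy = types.SimpleNamespace(CausalMask=lambda shape: ("legacy", shape))
tokamax = types.SimpleNamespace(CausalMask=lambda shape: ("tokamax", shape))

use_tokamax_splash = True  # stands in for self.config.use_tokamax_splash
mask_module = tokamax if use_tokamax_splash else legacy
print(mask_module.CausalMask(shape=(8, 8)))  # ('tokamax', (8, 8))
```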
@@ -1152,7 +1153,7 @@ def create_sa_config(config, query, key, attn_logits_soft_cap):
     if self.attention_type == AttentionType.LOCAL_SLIDING:
       if self.sliding_window_size is None:
         raise ValueError("Sliding_window_size must be set if Local Sliding attention type")
-      mask &= splash_attention_mask.LocalMask(
+      mask &= mask_module.LocalMask(
           shape=(query.shape[2], key.shape[2]),
           window_size=(self.sliding_window_size, self.sliding_window_size),
           offset=0,
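Composing masks with `&=` is the other half of the pattern: the causal (or full) base mask is intersected with a sliding-window `LocalMask`. Below is a dense NumPy sketch of that intersection, assuming simple left/right window semantics; the real splash-attention masks are lazy and never materialize the full array:

```python
import numpy as np

# Hypothetical stand-ins for the lazy mask classes: these materialize dense
# boolean arrays purely to illustrate how masks compose via `&`.
class _DenseMask:
  def __init__(self, array):
    self.array = array

  def __and__(self, other):
    # Intersection: a (q, kv) pair is attendable only if both masks allow it.
    return _DenseMask(self.array & other.array)

def causal_mask(shape):
  q, kv = np.indices(shape)
  return _DenseMask(q >= kv)

def local_mask(shape, window_size, offset=0):
  left, right = window_size
  q, kv = np.indices(shape)
  diff = q - kv + offset
  return _DenseMask((diff <= left) & (diff >= -right))

mask = causal_mask((8, 8))
# Sliding-window attention: causal AND within a +/-2 band of the diagonal.
mask &= local_mask((8, 8), window_size=(2, 2))
print(mask.array.astype(int))
```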
@@ -1775,7 +1776,7 @@ def __call__(
 
 
 # pylint: disable=protected-access
-class LoadBalancedCausalMask(tokamax_splash_mask._ComputableMask):
+class LoadBalancedCausalMask(splash_attention_mask._ComputableMask):
   """Lazy causal mask, prevents the model from attending to future tokens.
   Attributes:
     offset: Offset of q start wrt kv. A positive offset shifts the bottom
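A side note on the `LoadBalancedCausalMask` docstring above (the hunk cuts off mid-sentence): the `offset` field appears to follow the usual causal-with-offset convention, where position q may attend position kv whenever q + offset >= kv. A small illustrative predicate under that assumption; it deliberately ignores the load-balancing row reordering this class exists for:

```python
import numpy as np

def causal_with_offset(q_ids, kv_ids, offset=0):
  # Query q may attend key kv when q + offset >= kv. With offset > 0 the
  # causal boundary shifts so each query also sees `offset` future keys.
  return (q_ids + offset) >= kv_ids

q = np.arange(4)[:, None]
kv = np.arange(4)[None, :]
print(causal_with_offset(q, kv, offset=1).astype(int))
# [[1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]
#  [1 1 1 1]]
```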