 import gc
 import sys
 
+from flash_sparse_attn.utils.mask import create_mask
+
 # Import the compiled CUDA extension
 try:
     from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
@@ -65,42 +67,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-def prepare_mask(
-    hidden_states: torch.Tensor,
-    attn_bias: torch.Tensor,
-    causal_mask: torch.Tensor = None,
-    window_size: int = None,
-):
-    """
-    Args:
-        hidden_states: Input hidden states to determine dtype minimum value
-        attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length)
-        causal_mask: Optional causal mask to apply
-        window_size: Window size of tokens not masked
-
-    Returns:
-        tuple: (attn_bias, attn_mask)
-    """
-    dtype = hidden_states.dtype
-    min_dtype = torch.finfo(dtype).min
-
-    if attn_bias.shape[-1] > window_size:
-        if causal_mask is not None:
-            topk_values, topk_indices = torch.topk(
-                attn_bias.masked_fill(~causal_mask, min_dtype).detach(),
-                window_size, dim=-1, largest=True, sorted=False
-            )
-        else:
-            topk_values, topk_indices = torch.topk(
-                attn_bias,
-                window_size, dim=-1, largest=True, sorted=False
-            )
-        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype)
-    else:
-        attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
-    return attn_bias, attn_mask
-
-
 def dynamic_mask_attention_python(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
@@ -127,32 +93,38 @@ def dynamic_mask_attention_python(
     Returns:
         tuple: (attn_outputs, dq, dk, dv, dbias)
     """
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
+    )
+
     query_states_leaf = query_states
     key_states_leaf = key_states
     value_states_leaf = value_states
-
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
-    )
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
 
     key_states = repeat_kv(key_states, num_queries_per_kv)
     value_states = repeat_kv(value_states, num_queries_per_kv)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
+    attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None
     attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)
 
     # Sparse attention weight calculation
     attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))  # Dot product weights
     attn_weights = attn_weights * scaling + attn_bias  # Apply scaling and bias
-    attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf'))  # Apply mask
+    if attn_mask is not None:
+        attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf'))  # Apply mask
     attn_weights = F.softmax(attn_weights, dim=-1)  # Softmax normalization
     attn_outputs = torch.matmul(attn_weights, value_states)  # Weighted sum of values
     attn_outputs = attn_outputs.transpose(1, 2).contiguous()  # Transpose to [batch, query_len, num_heads, head_dim]
@@ -192,16 +164,25 @@ def dynamic_mask_attention_cuda(
     if flash_sparse_attn_func is None:
         raise ImportError("CUDA implementation not available")
 
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
+    num_queries_per_kv = num_heads // num_kv_heads
+
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
+    )
+
     query_states_leaf = query_states
     key_states_leaf = key_states
     value_states_leaf = value_states
-
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
-    )
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
 
@@ -259,29 +240,28 @@ def dynamic_mask_attention_triton(
     if triton_sparse_attn_func is None:
         raise RuntimeError("Triton implementation not available")
 
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
+    )
+
     query_states_leaf = query_states
     key_states_leaf = key_states
     value_states_leaf = value_states
-
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
-    )
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
 
-    # Repeat KV for multi-head attention (GQA support)
-    key_states = repeat_kv(key_states, num_queries_per_kv)
-    value_states = repeat_kv(value_states, num_queries_per_kv)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
-    attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)
-
     # Ensure correct data types and memory layout for Triton function
     query_states = query_states.transpose(1, 2)  # [batch, query_len, num_heads, head_dim]
     key_states = key_states.transpose(1, 2)  # [batch, key_len, num_heads, head_dim]
@@ -333,30 +313,38 @@ def dynamic_mask_attention_flex(
     if flex_sparse_attn_func is None:
         raise RuntimeError("Flex Attention implementation not available")
 
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
-    attn_bias.retain_grad()
+
+    query_states_leaf = query_states
+    key_states_leaf = key_states
+    value_states_leaf = value_states
+    attn_bias_leaf = attn_bias
+    attn_bias_leaf.retain_grad()
 
     # Repeat KV for multi-head attention (GQA support)
     key_states = repeat_kv(key_states, num_queries_per_kv)
     value_states = repeat_kv(value_states, num_queries_per_kv)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
+    attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None
     attn_bias = repeat_kv(attn_bias, num_queries_per_kv)
 
     # Ensure correct data types and memory layout for Flex function
     query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
     key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
     value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
-    attn_mask = attn_mask.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
-    attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
 
     # Call the Flex Attention implementation
     attn_outputs = flex_sparse_attn_func(
@@ -372,7 +360,7 @@ def dynamic_mask_attention_flex(
     # Backward pass
     attn_outputs.sum().backward()
 
-    return attn_outputs, query_states.grad, key_states.grad, value_states.grad, attn_bias.grad
+    return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad
 
 
 def analyze_differences(original_result, cuda_result, accuracy_threshold=0.95):
@@ -609,7 +597,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         device=device, dtype=dtype, requires_grad=True
     )
     attn_bias = torch.randn(
-        batch_size, num_kv_heads, query_len, key_len,
+        batch_size, num_kv_heads, 1, key_len,
         device=device, dtype=torch.bfloat16
     )
     cache_position = torch.arange(key_len - query_len, key_len, device=device)
@@ -843,7 +831,7 @@ def test_triton_backward_equivalence(accuracy_threshold=0.95):
         device=device, dtype=dtype, requires_grad=True
     )
     attn_bias = torch.randn(
-        batch_size, num_kv_heads, query_len, key_len,
+        batch_size, num_kv_heads, 1, key_len,
         device=device, dtype=torch.bfloat16
     )
     cache_position = torch.arange(key_len - query_len, key_len, device=device)
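For reference, below is a minimal standalone sketch of the top-k mask construction that the new create_mask(..., type="topk") call is expected to perform, mirroring the prepare_mask helper removed in this diff. The function name topk_attention_mask and its exact behavior are assumptions drawn from the removed code, not the flash_sparse_attn library's API.

import torch
from typing import Optional


def topk_attention_mask(
    attn_bias: torch.Tensor,              # [batch, num_kv_heads, 1 or query_len, key_len]
    causal_mask: Optional[torch.Tensor],  # boolean mask broadcastable to attn_bias, True = keep
    window_size: int,
) -> Optional[torch.Tensor]:
    """Keep only the window_size highest-bias key positions per query row."""
    min_dtype = torch.finfo(attn_bias.dtype).min
    if attn_bias.shape[-1] <= window_size:
        # Nothing to prune: fall back to the causal mask, or no mask at all.
        return causal_mask.expand_as(attn_bias) if causal_mask is not None else None
    # Exclude causally masked positions from the top-k selection.
    scores = attn_bias if causal_mask is None else attn_bias.masked_fill(~causal_mask, min_dtype)
    topk_values, topk_indices = torch.topk(
        scores.detach(), window_size, dim=-1, largest=True, sorted=False
    )
    attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool)
    # Positions whose selected score is min_dtype were fully masked out and stay False.
    attn_mask.scatter_(-1, topk_indices, topk_values != min_dtype)
    return attn_mask


if __name__ == "__main__":
    bias = torch.randn(2, 4, 1, 128)
    causal = torch.ones(2, 4, 1, 128, dtype=torch.bool)
    mask = topk_attention_mask(bias, causal, window_size=32)
    print(mask.shape, mask.sum(dim=-1))  # at most 32 keys kept per row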