Commit ed981b0

Merge pull request #209 from flash-algo/optim-banchmark
[BUG FIX] Unify masking utilities and improve performance
2 parents cfeeef0 + ade4ef3 commit ed981b0

File tree

4 files changed (+240 additions, -290 deletions)


benchmarks/backward_equivalence.py

Lines changed: 71 additions & 83 deletions
@@ -19,6 +19,8 @@
 import gc
 import sys
 
+from flash_sparse_attn.utils.mask import create_mask
+
 # Import the compiled CUDA extension
 try:
     from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
@@ -65,42 +67,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-def prepare_mask(
-    hidden_states: torch.Tensor,
-    attn_bias: torch.Tensor,
-    causal_mask: torch.Tensor = None,
-    window_size: int = None,
-):
-    """
-    Args:
-        hidden_states: Input hidden states to determine dtype minimum value
-        attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length)
-        causal_mask: Optional causal mask to apply
-        window_size: Window size of tokens not masked
-
-    Returns:
-        tuple: (attn_bias, attn_mask)
-    """
-    dtype = hidden_states.dtype
-    min_dtype = torch.finfo(dtype).min
-
-    if attn_bias.shape[-1] > window_size:
-        if causal_mask is not None:
-            topk_values, topk_indices = torch.topk(
-                attn_bias.masked_fill(~causal_mask, min_dtype).detach(),
-                window_size, dim=-1, largest=True, sorted=False
-            )
-        else:
-            topk_values, topk_indices = torch.topk(
-                attn_bias,
-                window_size, dim=-1, largest=True, sorted=False
-            )
-        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype)
-    else:
-        attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
-    return attn_bias, attn_mask
-
-
 def dynamic_mask_attention_python(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
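
Note on the change above: the benchmark-local prepare_mask helper is gone, and all four paths now call the shared create_mask utility from flash_sparse_attn.utils.mask with type="topk". As a rough, minimal sketch of the top-k windowed masking that the removed helper performed (and that the "topk" mode presumably covers), illustrative only and not the library's code:

    import torch

    def topk_window_mask(attn_bias, causal_mask=None, window_size=128):
        """Boolean mask keeping the window_size highest-bias keys per query (illustrative)."""
        min_dtype = torch.finfo(attn_bias.dtype).min
        if attn_bias.shape[-1] <= window_size:
            # Nothing to prune: fall back to the causal mask, or attend everywhere.
            if causal_mask is not None:
                return causal_mask.expand_as(attn_bias)
            return torch.ones_like(attn_bias, dtype=torch.bool)
        scores = attn_bias if causal_mask is None else attn_bias.masked_fill(~causal_mask, min_dtype)
        topk_values, topk_indices = torch.topk(scores, window_size, dim=-1, largest=True, sorted=False)
        # Scatter True at the selected positions, skipping keys that were fully masked out.
        return torch.zeros_like(attn_bias, dtype=torch.bool).scatter_(-1, topk_indices, topk_values != min_dtype)

Unlike the old helper, create_mask takes explicit batch_size/query_len/key_len and min_dtype arguments and apparently may return None when no masking is needed, which is why the callers now guard the repeat_kv and masked_fill steps with an "if attn_mask is not None" check.
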
@@ -127,32 +93,38 @@ def dynamic_mask_attention_python(
     Returns:
         tuple: (attn_outputs, dq, dk, dv, dbias)
     """
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
+    )
+
     query_states_leaf = query_states
     key_states_leaf = key_states
     value_states_leaf = value_states
-
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
-    )
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
 
     key_states = repeat_kv(key_states, num_queries_per_kv)
     value_states = repeat_kv(value_states, num_queries_per_kv)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
+    attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None
     attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)
 
     # Sparse attention weight calculation
     attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))  # Dot product weights
     attn_weights = attn_weights * scaling + attn_bias  # Apply scaling and bias
-    attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf'))  # Apply mask
+    if attn_mask is not None:
+        attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf'))  # Apply mask
     attn_weights = F.softmax(attn_weights, dim=-1)  # Softmax normalization
     attn_outputs = torch.matmul(attn_weights, value_states)  # Weighted sum of values
     attn_outputs = attn_outputs.transpose(1, 2).contiguous()  # Transpose to [batch, query_len, num_heads, head_dim]
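
For reference, the "python" path above is the ground truth the CUDA, Triton, and Flex kernels are checked against: dot-product scores, scaling, additive bias, boolean masking, softmax, then the weighted sum over values. A minimal self-contained sketch of that computation (shapes [batch, num_heads, seq_len, head_dim] after repeat_kv; names chosen here for illustration):

    import torch
    import torch.nn.functional as F

    def reference_attention(q, k, v, attn_bias, attn_mask=None, scaling=1.0):
        # Dot-product scores, scaled and shifted by the (broadcastable) attention bias.
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scaling + attn_bias
        if attn_mask is not None:
            # Positions excluded by the boolean mask get -inf so softmax sends them to zero.
            attn_weights = attn_weights.masked_fill(~attn_mask, float("-inf"))
        attn_weights = F.softmax(attn_weights, dim=-1)
        # Weighted sum of values, returned as [batch, seq_len, num_heads, head_dim].
        return torch.matmul(attn_weights, v).transpose(1, 2).contiguous()
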
@@ -192,16 +164,25 @@ def dynamic_mask_attention_cuda(
     if flash_sparse_attn_func is None:
         raise ImportError("CUDA implementation not available")
 
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
+    num_queries_per_kv = num_heads // num_kv_heads
+
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
+    )
+
     query_states_leaf = query_states
     key_states_leaf = key_states
     value_states_leaf = value_states
-
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
-    )
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
 
@@ -259,29 +240,28 @@ def dynamic_mask_attention_triton(
     if triton_sparse_attn_func is None:
         raise RuntimeError("Triton implementation not available")
 
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
+    )
+
     query_states_leaf = query_states
     key_states_leaf = key_states
     value_states_leaf = value_states
-
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
-    )
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
 
-    # Repeat KV for multi-head attention (GQA support)
-    key_states = repeat_kv(key_states, num_queries_per_kv)
-    value_states = repeat_kv(value_states, num_queries_per_kv)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
-    attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)
-
     # Ensure correct data types and memory layout for Triton function
     query_states = query_states.transpose(1, 2)  # [batch, query_len, num_heads, head_dim]
     key_states = key_states.transpose(1, 2)  # [batch, key_len, num_heads, head_dim]
@@ -333,30 +313,38 @@ def dynamic_mask_attention_flex(
     if flex_sparse_attn_func is None:
         raise RuntimeError("Flex Attention implementation not available")
 
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
-    attn_bias.retain_grad()
+
+    query_states_leaf = query_states
+    key_states_leaf = key_states
+    value_states_leaf = value_states
+    attn_bias_leaf = attn_bias
+    attn_bias_leaf.retain_grad()
 
     # Repeat KV for multi-head attention (GQA support)
     key_states = repeat_kv(key_states, num_queries_per_kv)
     value_states = repeat_kv(value_states, num_queries_per_kv)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
+    attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None
     attn_bias = repeat_kv(attn_bias, num_queries_per_kv)
 
     # Ensure correct data types and memory layout for Flex function
     query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
     key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
     value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
-    attn_mask = attn_mask.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
-    attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
 
     # Call the Flex Attention implementation
     attn_outputs = flex_sparse_attn_func(
@@ -372,7 +360,7 @@ def dynamic_mask_attention_flex(
     # Backward pass
     attn_outputs.sum().backward()
 
-    return attn_outputs, query_states.grad, key_states.grad, value_states.grad, attn_bias.grad
+    return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad
 
 
 def analyze_differences(original_result, cuda_result, accuracy_threshold=0.95):
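
The return-value change above is the actual bug fix for the Flex path: once query_states is reassigned to its transposed, contiguous copy, query_states.grad is None after backward (autograd only populates .grad on leaf tensors), so the old code returned empty gradients. Saving the leaf tensors first, and calling retain_grad() on attn_bias when it is not a leaf, restores gradients that can be compared across implementations. A small illustration of the pitfall (hypothetical shapes):

    import torch

    q = torch.randn(2, 4, 8, 16, requires_grad=True)  # leaf tensor created by the test
    q_leaf = q                                         # keep a handle before any reshaping
    q = q.transpose(1, 2).contiguous()                 # non-leaf tensor fed to the kernel

    (q * 2.0).sum().backward()

    print(q.grad)             # None (with a warning): .grad is not retained on non-leaf tensors
    print(q_leaf.grad.shape)  # torch.Size([2, 4, 8, 16]): read gradients from the saved leaf
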
@@ -609,7 +597,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         device=device, dtype=dtype, requires_grad=True
     )
     attn_bias = torch.randn(
-        batch_size, num_kv_heads, query_len, key_len,
+        batch_size, num_kv_heads, 1, key_len,
         device=device, dtype=torch.bfloat16
     )
     cache_position = torch.arange(key_len - query_len, key_len, device=device)
@@ -843,7 +831,7 @@ def test_triton_backward_equivalence(accuracy_threshold=0.95):
         device=device, dtype=dtype, requires_grad=True
     )
     attn_bias = torch.randn(
-        batch_size, num_kv_heads, query_len, key_len,
+        batch_size, num_kv_heads, 1, key_len,
         device=device, dtype=torch.bfloat16
     )
     cache_position = torch.arange(key_len - query_len, key_len, device=device)
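
Both equivalence tests now build attn_bias with a singleton query dimension, (batch_size, num_kv_heads, 1, key_len), instead of a full (batch_size, num_kv_heads, query_len, key_len) tensor, so the same per-key bias broadcasts across every query position (and across the grouped query heads via repeat_kv). A quick shape check of that broadcasting (toy sizes, not the benchmark's configuration):

    import torch

    batch_size, num_kv_heads, query_len, key_len = 1, 2, 4, 16

    attn_bias = torch.randn(batch_size, num_kv_heads, 1, key_len, dtype=torch.bfloat16)
    scores = torch.zeros(batch_size, num_kv_heads, query_len, key_len, dtype=torch.bfloat16)

    # The singleton query dim broadcasts, so every query row sees the same per-key bias.
    print((scores + attn_bias).shape)  # torch.Size([1, 2, 4, 16])
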
