
Commit b500c36

[FEATURE SUPPORT] Variable-Length Attention with Padding-Free Execution
2 parents 291dbd5 + 587f0a6 · commit b500c36

19 files changed: +1785 additions, -623 deletions
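
Most of the lines below are a parameter rename (scale to softmax_scale), while the headline feature is variable-length attention with padding-free execution: sequences of unequal length are packed into a single token stream with cumulative sequence offsets so no compute is spent on padding tokens. The new kernel entry points are not shown in this excerpt; the sketch below only illustrates the usual packing step in plain PyTorch, and the helper name, mask layout, and cu_seqlens convention are assumptions (modeled on FlashAttention-style varlen interfaces), not part of the commit.

import torch

def pack_padded_batch(hidden_states, attention_mask):
    """Illustrative packing step for padding-free attention (not from this commit).

    hidden_states:  [batch, seq_len, hidden]
    attention_mask: [batch, seq_len], 1 for real tokens, 0 for padding.
    Returns packed tokens [total_tokens, hidden], the flat indices used to unpad,
    cumulative sequence lengths [batch + 1], and the longest sequence length.
    """
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                 # tokens per sequence
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)             # prepend 0 offset
    )
    batch, seq_len, hidden = hidden_states.shape
    packed = hidden_states.reshape(batch * seq_len, hidden)[indices]        # drop padding tokens
    return packed, indices, cu_seqlens, int(seqlens.max())

This mirrors the unpad_input utility commonly used in front of varlen attention kernels; the actual API added by this commit may differ.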

.github/ISSUE_TEMPLATE/performance_issue.yml (1 addition & 1 deletion)

@@ -1,5 +1,5 @@
 name: Performance issue
-description: Report performance problems or optimisation opportunities
+description: Report performance problems or optimization opportunities
 title: "[PERFORMANCE] "
 labels:
   - performance

.github/PULL_REQUEST_TEMPLATE/performance_optimization.yml (1 addition & 1 deletion)

@@ -7,7 +7,7 @@ body:
   - type: markdown
     attributes:
       value: |
-        Document the optimisation, methodology, and results so reviewers can validate gains and correctness.
+        Document the optimization, methodology, and results so reviewers can validate gains and correctness.
   - type: textarea
     id: summary
     attributes:

README.md (2 additions & 2 deletions)

@@ -195,7 +195,7 @@ output = flash_dmattn_func(
     attn_mask=attention_mask,
     attn_bias=attention_bias,
     is_causal=True,
-    scale=1.0/math.sqrt(head_dim),
+    softmax_scale=1.0/math.sqrt(head_dim),
 )

 print(f"Output shape: {output.shape}") # [1, 256, 2, 64]

@@ -216,7 +216,7 @@ output = flash_dmattn_func(
     attn_mask=attention_mask,
     attn_bias=attention_bias,
     is_causal=True,
-    scale=1.0/math.sqrt(head_dim)
+    softmax_scale=1.0/math.sqrt(head_dim)
 )

 # Backward pass

README_zh.md (2 additions & 2 deletions)

@@ -195,7 +195,7 @@ output = flash_dmattn_func(
     attn_mask=attention_mask,
     attn_bias=attention_bias,
     is_causal=True,
-    scale=1.0/math.sqrt(head_dim),
+    softmax_scale=1.0/math.sqrt(head_dim),
 )

 print(f"输出形状: {output.shape}") # [1, 256, 2, 64]

@@ -216,7 +216,7 @@ output = flash_dmattn_func(
     attn_mask=attention_mask,
     attn_bias=attention_bias,
     is_causal=True,
-    scale=1.0/math.sqrt(head_dim)
+    softmax_scale=1.0/math.sqrt(head_dim)
 )

 # 反向传播

benchmarks/backward_equivalence.py (3 additions & 3 deletions)

@@ -266,7 +266,7 @@ def dynamic_mask_attention_cuda(
         attn_mask=attn_mask, # mask: [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias, # bias: [batch, num_kv_heads, query_len, key_len]
         is_causal=is_causal, # causal masking
-        scale=scaling, # scaling factor
+        softmax_scale=scaling, # scaling factor
         softcap=0.0,
         deterministic=False,
         return_attn_probs=False

@@ -351,7 +351,7 @@ def dynamic_mask_attention_triton(
         attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k]
         attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k]
         is_causal=is_causal, # causal masking
-        scale=scaling # scaling factor
+        softmax_scale=scaling # scaling factor
     )

     # Backward pass

@@ -424,7 +424,7 @@ def dynamic_mask_attention_flex(
         attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len]
         attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len]
         is_causal=is_causal, # is_causal: whether to apply causal masking
-        scale=scaling # scaling factor
+        softmax_scale=scaling # scaling factor
     )

     # Backward pass
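
Every benchmark call site in this commit makes the same rename from scale= to softmax_scale=. For downstream scripts that still pass the old keyword, a small forwarding shim like the hypothetical one below (not part of the commit) keeps them working during the migration.

def with_renamed_scale(attn_func):
    """Wrap an attention function so legacy `scale=` calls map to `softmax_scale=`."""
    def wrapper(*args, scale=None, **kwargs):
        if scale is not None:
            kwargs.setdefault("softmax_scale", scale)  # follow this commit's rename
        return attn_func(*args, **kwargs)
    return wrapper

# Usage (hypothetical): flash_dmattn_func = with_renamed_scale(flash_dmattn_func)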

benchmarks/backward_performance.py (4 additions & 4 deletions)

@@ -183,7 +183,7 @@ def scaled_dot_product_attention_backward(
         key_states, # [batch, num_kv_heads, key_len, head_dim]
         value_states, # [batch, num_kv_heads, key_len, head_dim]
         attn_mask=causal_mask,
-        scale=scaling,
+        softmax_scale=scaling,
         # is_causal=is_causal if query_len == key_len else False,
         enable_gqa=True
     )

@@ -262,7 +262,7 @@ def dynamic_mask_attention_backward_cuda(
         attn_mask=attn_mask, # mask: [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias, # bias: [batch, num_kv_heads, query_len, key_len]
         is_causal=is_causal, # causal masking
-        scale=scaling, # scaling factor
+        softmax_scale=scaling, # scaling factor
         softcap=0.0,
         deterministic=False,
         return_attn_probs=False

@@ -351,7 +351,7 @@ def dynamic_mask_attention_backward_triton(
         attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k]
         attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k]
         is_causal=is_causal, # causal masking
-        scale=scaling # scaling factor
+        softmax_scale=scaling # scaling factor
     )

     torch.cuda.synchronize()

@@ -433,7 +433,7 @@ def dynamic_mask_attention_backward_flex(
         attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len]
         attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len]
         is_causal=is_causal, # is_causal: whether to apply causal masking
-        scale=scaling # scaling factor
+        softmax_scale=scaling # scaling factor
     )

     torch.cuda.synchronize()

benchmarks/forward_equivalence.py (3 additions & 3 deletions)

@@ -253,7 +253,7 @@ def dynamic_mask_attention_cuda(
         attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len]
         is_causal=is_causal,
-        scale=scaling,
+        softmax_scale=scaling,
         softcap=0.0,
         deterministic=True,
         return_attn_probs=return_softmax

@@ -329,7 +329,7 @@ def dynamic_mask_attention_triton(
         attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k]
         attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k]
         is_causal=is_causal, # causal masking
-        scale=scaling # scaling factor
+        softmax_scale=scaling # scaling factor
     )

     return attn_outputs # [batch, query_len, num_heads, head_dim]

@@ -398,7 +398,7 @@ def dynamic_mask_attention_flex(
         attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len]
         attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len]
         is_causal=is_causal, # is_causal: whether to apply causal masking
-        scale=scaling # scaling factor
+        softmax_scale=scaling # scaling factor
     )

     return attn_outputs # [batch, query_len, num_heads, head_dim]

benchmarks/forward_performance.py (4 additions & 4 deletions)

@@ -186,7 +186,7 @@ def scaled_dot_product_attention_cuda(
         key_states,
         value_states,
         attn_mask=causal_mask,
-        scale=scaling,
+        softmax_scale=scaling,
         # is_causal=is_causal if query_len == key_len else False,
         enable_gqa=True
     )

@@ -262,7 +262,7 @@ def dynamic_mask_attention_cuda(
         attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len]
         is_causal=is_causal,
-        scale=scaling,
+        softmax_scale=scaling,
         softcap=0.0,
         deterministic=False,
         return_attn_probs=return_softmax

@@ -348,7 +348,7 @@ def dynamic_mask_attention_triton(
         attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k]
         attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k]
         is_causal=is_causal, # causal masking
-        scale=scaling # scaling factor
+        softmax_scale=scaling # scaling factor
     )

     torch.cuda.synchronize()

@@ -427,7 +427,7 @@ def dynamic_mask_attention_flex(
         attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len]
         attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len]
         is_causal=is_causal, # is_causal: whether to apply causal masking
-        scale=scaling # scaling factor
+        softmax_scale=scaling # scaling factor
     )

     torch.cuda.synchronize()
