[FEATURE SUPPORT] Triton special compact dynamic-mask attention: 1.6× faster fwd+bwd, numerically equivalent #206
Conversation
Enables fused Triton forward/backward paths for dynamic masked attention to reduce padding overhead and deliver faster windowed attention execution.
Introduces reusable top-k extraction on the bias tensor to simplify downstream mask logic.
Pull Request Overview
This PR adds a specialized Triton implementation for flash dynamic masked attention along with a utility function to extract top-k indices from attention bias. The implementation introduces a gather-based approach where attention is computed only on a subset of key-value pairs selected by top-k indices.
Key changes:
- Added `topk_indices` utility function to extract and sort top-k indices from attention bias
- Implemented a new Triton-based flash attention variant that uses gathered K/V/bias values
- Added preprocessing, forward, and backward kernels for the specialized implementation
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| flash_dmattn/utils/mask.py | Added topk_indices function to compute sorted top-k indices from attention bias |
| flash_dmattn/flash_dmattn_triton_special.py | New file implementing specialized Triton kernels for flash dynamic masked attention with gather-based optimization |
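As a rough illustration of what the `topk_indices` helper does (the exact signature, defaults, and bias layout are assumptions based on the docstring fragment quoted in the first review comment below), a minimal sketch:

```python
# Illustrative sketch only -- not the PR's implementation.
import torch

def topk_indices_sketch(attn_bias: torch.Tensor, window_size: int, **kwargs) -> torch.Tensor:
    # attn_bias: (batch_size, num_kv_heads, key_len)
    # Select the window_size largest bias entries per (batch, kv_head) and
    # return their key positions sorted in ascending order (int64).
    _, indices = torch.topk(attn_bias, k=window_size, dim=-1)
    return indices.sort(dim=-1).values  # (batch_size, num_kv_heads, window_size)
```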

        (batch_size, num_kv_heads, key_len).
    window_size (int): The number of top elements to consider for the mask.
    **kwargs: Additional keyword arguments.

Copilot AI · Nov 7, 2025:
Trailing whitespace should be removed.

        mask=valid_idx,
        other=0.0,
    )

Copilot AI · Nov 7, 2025:
Trailing whitespace should be removed.

        mask=(start_n + offs_n) < window_size,
        other=0.0,
    )

Copilot AI · Nov 7, 2025:
Trailing whitespace should be removed.

    # Compute dp
    dp = tl.dot(do, tl.trans(v))

Copilot AI · Nov 7, 2025:
Trailing whitespace should be removed.

    )

    # We could have padded the head dimension
    dq = dq[..., : do.shape[-1]]

Copilot AI · Nov 7, 2025:
Trailing whitespace should be removed.
Suggested change: drop the trailing whitespace after `dq = dq[..., : do.shape[-1]]`.

    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    db = tl.zeros([BLOCK_N], dtype=tl.float32)

    # Load k and v, them will stay in SRAM throughout

Copilot AI · Nov 7, 2025:
Corrected spelling of 'them' to 'they'.
Suggested change:
-    # Load k and v, them will stay in SRAM throughout
+    # Load k and v, they will stay in SRAM throughout

    acc_s += tl.where(m, 0, float("-inf"))

    # Compute p
    m_ij = tl.maximum(tl.max(acc_s, 1), lse_i)

Copilot AI · Nov 7, 2025:
The variable lse_i is initialized with negative infinity at line 359. When all mask elements are False and any_active prevents computation, lse_i remains negative infinity throughout. At line 516, this would store negative infinity values. Consider adding a comment explaining this edge case behavior or adding a guard to handle empty masks explicitly.
Suggested change:
-    m_ij = tl.maximum(tl.max(acc_s, 1), lse_i)
+    # Guard against the case where all mask elements are False and lse_i remains -inf.
+    if tl.all(~m):
+        # All elements are masked out; set m_ij to 0 (or another safe value).
+        m_ij = tl.zeros([acc_s.shape[0]], dtype=acc_s.dtype)
+    else:
+        m_ij = tl.maximum(tl.max(acc_s, 1), lse_i)

    lse_i = tl.load(LSE + offs_m_curr)
    # p = tl.exp(acc_s - lse_i[:, None])
    p = tl.exp(acc_s - tl.where(lse_i > float("-inf"), lse_i, 0.0)[:, None])

Copilot AI · Nov 7, 2025:
This guard handles the case when lse_i is negative infinity, but the forward kernel at line 463 doesn't have the same protection. Consider applying consistent handling of negative infinity LSE values in both forward and backward passes, or document why they differ.
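For intuition on why the guard matters: when a row is fully masked, the scores and the stored `lse` both stay at negative infinity, and `exp(-inf - (-inf))` is NaN, whereas the guarded form collapses to zero. A tiny PyTorch check of the same arithmetic as the quoted line:

```python
import torch

lse = torch.tensor(float("-inf"))
scores = torch.full((4,), float("-inf"))

print(torch.exp(scores - lse))  # tensor([nan, nan, nan, nan]) -- unguarded
guarded = torch.where(lse > float("-inf"), lse, torch.tensor(0.0))
print(torch.exp(scores - guarded))  # tensor([0., 0., 0., 0.]) -- guarded
```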

    assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA"
    assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128"
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    seqlen_k_rounded = math.ceil(seqlen_k / 128) * 128

Copilot AI · Nov 7, 2025:
Variable seqlen_k_rounded is not used.
Suggested change:
-    seqlen_k_rounded = math.ceil(seqlen_k / 128) * 128
+    # seqlen_k_rounded = math.ceil(seqlen_k / 128) * 128  # Removed unused variable
Summary
Design
- Gather the selected K/V/bias into compact tiles, run a streaming softmax (with stable `lse`), and accumulate outputs. Skip fully-inactive tiles.
- Attend only to the key/value positions selected by `attn_indices`.
- GQA/MQA is handled via `h_h_k_ratio = nheads // nheads_k` in both forward and backward.
- Keep `lse` tracking identical to the baseline DMA. Causal masking is supported and pre-applied into CuM.
- Constraints: `head_dim ≤ 128`, dtype in {fp16, bf16}; `attn_indices` is int64 and must be valid for the chosen `window_size`.
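Ignoring tiling, the GQA head mapping, causal masking, and the streaming-softmax bookkeeping, the math the compact path computes per (batch, head) reduces to the following dense PyTorch reference (the per-key bias layout is an assumption for illustration):

```python
import torch

def compact_attention_reference(q, k, v, bias, idx, softmax_scale):
    # q: (seqlen_q, d); k, v: (seqlen_k, d); bias: (seqlen_k,); idx: (window_size,) int64
    # Attend only to the window_size key positions selected by idx.
    k_g, v_g, b_g = k[idx], v[idx], bias[idx]      # gather compact K/V/bias
    scores = (q @ k_g.T) * softmax_scale + b_g     # (seqlen_q, window_size)
    p = torch.softmax(scores.float(), dim=-1)
    return p @ v_g.float()                         # (seqlen_q, d)
```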
Changes

- Adds `flash_dmattn.flash_dmattn_triton_special.triton_dmattn_func(query, key, value, attn_bias, attn_indices, is_causal=False, softmax_scale=None)`.
- `_fwd_preprocess`: gather K/V/B into CuK/CuV/CuB and construct CuM with row/col/causal masking.
- `_fwd_kernel`: streaming softmax over compact tiles with stable `lse`.
- `_bwd_preprocess_do_o_dot`: compute per-row Delta.
- `_bwd_kernel` + `_bwd_kernel_one_col_block`: column-block backward with fp32 accumulators, then scatter back via indices.
- `triton_dmattn_func` convenience entrypoint.
- Assumes `attn_indices` are valid indices into `key_len`.
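A sketch of how the pieces listed above might be wired together (tensor layouts, the bias tensor being shared between `topk_indices` and the attention call, and keyword usage are assumptions; the API docs added in this PR are authoritative):

```python
import torch
from flash_dmattn.utils.mask import topk_indices
from flash_dmattn.flash_dmattn_triton_special import triton_dmattn_func

# Assumed shapes/layout for illustration (bf16, head_dim <= 128, GQA with 8 Q heads per KV head).
batch, seqlen, nheads, nheads_k, head_dim, window = 2, 2048, 16, 2, 64, 256
q = torch.randn(batch, seqlen, nheads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, nheads_k, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn_like(k)
attn_bias = torch.randn(batch, nheads_k, seqlen, device="cuda", dtype=torch.bfloat16)

attn_indices = topk_indices(attn_bias, window_size=window)  # sorted int64, (batch, nheads_k, window)
out = triton_dmattn_func(q, k, v, attn_bias, attn_indices,
                         is_causal=True, softmax_scale=head_dim ** -0.5)
```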
Implementation notes

- Row/column validity and causal constraints are pre-applied into the compact mask CuM.
- `attn_indices` should be in [0, key_len); the generator (e.g., `topk_indices`) guarantees validity. If external indices are used, optional guards can be added during scatter to drop invalid entries.
- `head_dim ≤ 128`; extending beyond may require additional kernel variants.

Tests
Documentation
- Add `triton_dmattn_func` to the English API docs, including input shapes, dtype constraints, and notes on `attn_indices`.
- Show how to construct `attn_indices` (e.g., via `topk_indices`) and switch between baseline and special Triton path.
- Document the supported constraints (`head_dim ≤ 128`, bf16/fp16).

Checklist