Improve numerical stability in sparse attention with sink auxiliary logits #220
Conversation
Extends forward/backward sparse attention kernels to include optional sink auxiliary logits, improving numerical stability when modeling sink tokens. Adds dedicated backward reduction for the auxiliary scalar and wires the tensor plus gradient through the autograd wrapper and functional API.
Enables optional auxiliary scaling tensor to flow through all forward/backward sparse attention paths, validating shape/dtype and ensuring gradients propagate when s_aux requires_grad.
Adds optional sink logit tensors to the fwd/bwd paths, defaulting to $-\infty$ values to keep kernel parameters valid, and exposes gradients as outputs. Disables the seqlen-ngroups swap when sink logits are used so the head dimension remains aligned with the provided data.
Introduces an auxiliary max reference for normalize_softmax_lse so chunked accumulations rescale into a common frame, keeping log-sum-exp finite when combining scores.
Clarifies that the kernels always expect valid saux/dsaux pointers, so callers that skip the feature must wire in sentinel tensors, preventing undefined memory access when the sink logits are unused.
Reads the per-head auxiliary scale from params and feeds it into the softmax normalization so the LSE uses the precomputed adjustment.
Accounts for the dropped sink term by computing its shared ds contribution per block and atomically accumulating it, aligning backward gradients with the forward sink probability handling.
Clarifies sink softmax comments and simplifies the backward description. Ensures the optional auxiliary logits path flows through contiguous inputs and returns gradients in a consistent argument order.
Converts non-float32 sink logits to float32 instead of rejecting them, keeping attention paths usable. Documents the sink tensor in both APIs and makes it optional in the varlen signature for clarity.
Drops misleading comments implying the auxiliary sink buffers are optional so the documentation matches the actual kernel expectations.
Adds the precomputed auxiliary scale from params to the softmax normalization step so attention rows reuse the per-head auxiliary state instead of recomputing it.
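Several of the notes above refer to rescaling into a common extended-max reference frame. A minimal NumPy sketch of that log-sum-exp combination, assuming the row sum has been accumulated relative to the running row max, looks like this (illustrative only, not the kernel code):

```python
import numpy as np

# Minimal sketch: fold a per-head sink logit into the log-sum-exp by
# rescaling both terms into an extended max reference frame.
def lse_with_sink(row_max, row_sum, s_aux):
    # row_sum is assumed to hold sum_j exp(score_j - row_max) for each row
    max_ext = np.maximum(row_max, s_aux)                        # shared reference frame
    sum_ext = row_sum * np.exp(row_max - max_ext) + np.exp(s_aux - max_ext)
    return max_ext + np.log(sum_ext)                            # finite for finite logits
```

With the $-\infty$ sentinel for s_aux, the extra term vanishes and the expression reduces to the ordinary log-sum-exp, which is why the kernels can always consume the auxiliary value.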
Pull request overview
This PR enhances numerical stability in sparse attention mechanisms by introducing support for sink auxiliary logits (s_aux). The sink auxiliary logits are added as an extra "sink" position in the softmax computation, where the sink probability is computed but dropped from the output, allowing other attention weights to be reduced for improved numerical stability.
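As a point of reference for that semantics, here is a minimal dense PyTorch sketch; it ignores the sparse mask, assumes the sink logit is not multiplied by the softmax scale, and is illustrative only rather than the kernel implementation:

```python
import torch

def sink_attention_reference(q, k, v, s_aux, softmax_scale):
    # q, k, v: (batch, nheads, seqlen, headdim); s_aux: (nheads,) float32
    scores = torch.einsum("bhqd,bhkd->bhqk", q.float(), k.float()) * softmax_scale
    sink = s_aux.view(1, -1, 1, 1).expand(*scores.shape[:-1], 1)
    logits = torch.cat([scores, sink], dim=-1)      # append one sink column per row
    probs = torch.softmax(logits, dim=-1)           # sink probability is computed...
    probs = probs[..., :-1]                         # ...but dropped from the output
    return torch.einsum("bhqk,bhkd->bhqd", probs, v.float()).to(v.dtype)
```

Because the dropped sink column absorbs part of the probability mass, the remaining attention weights sum to less than one, which is the stabilizing effect the overview describes.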
Key Changes
- Added `s_aux` parameter (shape: `(nheads,)`) to the forward and backward attention kernels across the Triton and CUDA implementations
- Modified softmax normalization to rescale into an extended max reference frame that includes the sink auxiliary logit
- Updated the backward pass to compute gradients for sink auxiliary logits using atomic additions (a brief derivation follows this list)
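For the third item, the gradient the backward pass accumulates follows from standard softmax calculus under the semantics above (the sink position effectively attends to a zero value vector):

$$
p_{\text{sink}}(b,h,q) = \exp\big(s_{\text{aux}}(h) - \operatorname{lse}(b,h,q)\big),
\qquad
\frac{\partial \mathcal{L}}{\partial s_{\text{aux}}(h)} = -\sum_{b,q} p_{\text{sink}}(b,h,q)\, D(b,h,q),
$$

where $D(b,h,q) = \sum_d \partial\mathcal{L}/\partial O_{b,h,q,d} \cdot O_{b,h,q,d}$ is the row dot product already produced by the backward preprocessing (the dP_sum term in the kernels below). Each thread block therefore contributes a partial sum that is combined with one atomic add per head.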
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `flash_sparse_attn/flash_sparse_attn_triton.py` | Added `HAS_SAUX` flag, sink auxiliary logit handling in the forward kernel, new `_bwd_saux_kernel` for gradient computation, and updated function signatures |
| `flash_sparse_attn/flash_sparse_attn_interface.py` | Updated forward/backward function signatures to accept and return `s_aux`/`ds_aux`, added dtype validation and conversion for `s_aux` |
| `csrc/flash_sparse_attn/src/softmax.h` | Modified `normalize_softmax_lse` to accept an `s_aux` parameter and rescale the softmax computation into an extended max reference frame |
| `csrc/flash_sparse_attn/src/flash_fwd_kernel.h` | Updated the forward kernel to load `s_aux` per head and pass it to softmax normalization |
| `csrc/flash_sparse_attn/src/flash_bwd_kernel.h` | Added sink gradient computation using atomic adds to `dsaux_ptr` |
| `csrc/flash_sparse_attn/src/flash.h` | Added `saux_ptr` and `dsaux_ptr` fields to the parameter structs |
| `csrc/flash_sparse_attn/flash_api.cpp` | Added `s_aux` parameter handling, validation, and default tensor creation; disabled the seqlenq-ngroups swap when `s_aux` is provided; updated return values to include `ds_aux` |
    @@ -408,6 +424,34 @@ def _bwd_preprocess_do_o_dot(
        tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
Copilot AI (Dec 27, 2025):
The tl.store operation for Delta should include a mask to avoid writing out-of-bounds values when offs_m >= seqlen_q. While the loaded values are masked with other=0.0, the store operation should also be masked to prevent potential memory corruption in the rounded buffer region. Add mask=offs_m < seqlen_q to the store operation.
Suggested change:

    -    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
    +    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta, mask=offs_m < seqlen_q)
    const float s_aux = reinterpret_cast<const float *>(params.saux_ptr)[bidh];
    Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, s_aux);
Copilot AI (Dec 27, 2025):
The forward kernel unconditionally dereferences params.saux_ptr without checking if it's null. According to the API code in flash_api.cpp, when saux_ is not provided, a tensor filled with -inf is created. However, the kernel should either check for null or the comment in flash.h should be updated to clarify that this pointer must always be valid (never null). The current comment "filled with -inf when the feature is not used" suggests it's always allocated, which should be made more explicit.
    const float s_aux = reinterpret_cast<const float *>(params.saux_ptr)[bidh];
    float ds_aux_partial = 0.f;
    #pragma unroll
    for (int mi = 0; mi < size(lse); ++mi) {
        // If lse == inf (OOB row), exp(s_aux - inf) = 0 so this contributes 0.
        ds_aux_partial += -expf(s_aux - float(lse(mi))) * float(dP_sum(mi));
    }
    FLASH_NAMESPACE::SumOp<float> sum_op;
    ds_aux_partial = FLASH_NAMESPACE::Allreduce<32>::run(ds_aux_partial, sum_op);
    if ((tidx & 31) == 0) {
        atomicAdd(&reinterpret_cast<float *>(params.dsaux_ptr)[bidh], ds_aux_partial);
    }
Copilot AI (Dec 27, 2025):
Similar to the forward kernel, the backward kernel unconditionally dereferences params.saux_ptr and params.dsaux_ptr without null checks. The code relies on these pointers always being valid (allocated to -inf tensors when the feature is not used). This should be documented more clearly or the code should add explicit null checks for safety.
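As a sanity check on that accumulation, a dense PyTorch sketch of the same per-row contribution (names and shapes assumed for illustration) is:

```python
import torch

def ds_aux_reference(s_aux, lse, dP_sum):
    # s_aux: (nheads,); lse, dP_sum: (batch, nheads, seqlen_q)
    # p_sink = exp(s_aux - lse); rows padded with lse == +inf contribute exactly 0.
    p_sink = torch.exp(s_aux.view(1, -1, 1) - lse)
    return -(p_sink * dP_sum).sum(dim=(0, 2))       # one scalar gradient per head
```

This matches the `-expf(s_aux - lse(mi)) * dP_sum(mi)` term accumulated per row in the snippet above.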
        dbias_expanded = dbias_expanded.sum(dim=0, keepdim=True)
        dbias.copy_(dbias_expanded)
    -   return dq, dk, dv, dbias if has_bias else None
    +   return dq, dk, dv, dbias if has_bias else None, (ds_aux.to(s_aux.dtype) if has_saux else None)
Copilot AI (Dec 27, 2025):
In the backward function, when has_saux is True but s_aux is not a tensor (which could happen if it's stored as an empty tensor), calling s_aux.dtype will fail. The code should check if s_aux.numel() > 0 before calling .dtype, similar to how it's done elsewhere in the codebase. Consider changing this to: ds_aux.to(s_aux.dtype) if (has_saux and s_aux.numel() > 0) else None.
Suggested change:

    -    return dq, dk, dv, dbias if has_bias else None, (ds_aux.to(s_aux.dtype) if has_saux else None)
    +    return dq, dk, dv, dbias if has_bias else None, (
    +        ds_aux.to(s_aux.dtype) if (has_saux and s_aux.numel() > 0) else None
    +    )
    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
        float sum = row_sum(mi);
        float max_scores = row_max(mi);
        float max_ext = max_scores > s_aux ? max_scores : s_aux;
Copilot AI (Dec 27, 2025):
The condition max_scores > s_aux ? max_scores : s_aux could produce incorrect results when either value is NaN, since NaN comparisons always return false. When dealing with numerical stability in softmax computations, it's safer to use fmaxf(max_scores, s_aux) which handles NaN values according to IEEE 754 semantics (returns the non-NaN value if one operand is NaN).
Suggested change:

    -    float max_ext = max_scores > s_aux ? max_scores : s_aux;
    +    float max_ext = fmaxf(max_scores, s_aux);
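A quick NumPy illustration of the IEEE 754 behavior the suggestion relies on (np.fmax mirrors fmaxf by ignoring a NaN operand, while np.maximum and a plain comparison do not):

```python
import numpy as np

nan = float("nan")
print(np.fmax(nan, 1.0), np.fmax(1.0, nan))        # 1.0 1.0  -> NaN operand is ignored
print(np.maximum(nan, 1.0), np.maximum(1.0, nan))  # nan nan  -> NaN propagates
print(nan if nan > 1.0 else 1.0)                   # 1.0      -> comparison picks the other side here,
print(1.0 if 1.0 > nan else nan)                   # nan      -> ...but not when NaN is on the right
```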
    at::Tensor saux;
    if (saux_.has_value()) {
        saux = saux_.value();
        TORCH_CHECK(saux.dtype() == torch::kFloat32, "s_aux must have dtype float32");
        CHECK_DEVICE(saux);
        TORCH_CHECK(saux.dim() == 1, "s_aux must be 1D with shape (num_heads,)");
        TORCH_CHECK(saux.size(0) == num_heads, "s_aux must have shape (num_heads,)");
Copilot AI (Dec 27, 2025):
The validation check for s_aux.size(0) == num_heads occurs after num_heads may have been modified by the seqlenq_ngroups_swapped logic (line 490). However, this is actually correct because when has_saux is true, seqlenq_ngroups_swapped is forced to false (line 460), so num_heads is never modified. Still, for code clarity and to prevent future bugs, consider validating the s_aux shape against the original num_heads value before the swap logic, or add a comment explaining why this order is safe.