
Conversation

@LoserCheems
Collaborator

Summary

  • This update enhances numerical stability in sparse attention mechanisms by introducing support for sink auxiliary logits.

Root Cause

  • The previous implementation had no handling for sink auxiliary logits, which led to numerical instability in the forward and backward passes when modeling sink tokens.

Changes

  • Added support for sink auxiliary logits in forward and backward sparse attention kernels.
  • Improved softmax normalization to account for auxiliary scaling, ensuring consistent gradient propagation.
  • Updated documentation to clarify the necessity of auxiliary buffers.

Reproduction

  • Run sparse attention with various configurations of sink auxiliary logits (e.g., with s_aux provided vs. omitted) and check that outputs and gradients stay finite; a minimal sketch follows below.
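
A minimal repro sketch is below. It assumes a functional entry point named flash_sparse_attn_func that accepts an s_aux keyword; the import path, name, and full signature are illustrative assumptions, not the confirmed public API.

```python
# Hypothetical repro sketch: function name, import path, and signature are assumptions.
import torch

from flash_sparse_attn import flash_sparse_attn_func  # assumed entry point

batch, seqlen, nheads, headdim = 2, 512, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda",
                dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Per-head sink auxiliary logits; float32 as required by the new validation.
s_aux = torch.zeros(nheads, device="cuda", dtype=torch.float32, requires_grad=True)

out = flash_sparse_attn_func(q, k, v, s_aux=s_aux)
out.sum().backward()

assert torch.isfinite(out).all()
assert s_aux.grad is not None and torch.isfinite(s_aux.grad).all()
```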

Tests

  • New tests added to validate the behavior of sink auxiliary logits in both forward and backward passes.

Compatibility

  • No breaking changes introduced; existing functionality remains intact.

Checklist

  • Linked issue provided
  • Adds or updates tests
  • Updates docs if needed
  • No perf regressions

Extends the forward/backward sparse attention kernels with optional sink auxiliary logits, improving numerical stability when modeling sink tokens. Adds a dedicated backward reduction for the auxiliary scalar and wires the tensor plus its gradient through the autograd wrapper and the functional API.

  • Enables an optional auxiliary scaling tensor to flow through all forward/backward sparse attention paths, validating shape and dtype and ensuring gradients propagate when s_aux has requires_grad set.
  • Adds optional sink logits tensors to the fwd/bwd paths, defaulting to $-\infty$ values so kernel parameters stay valid, and exposes their gradients as outputs.
  • Disables the seqlenq-ngroups swap when sink logits are used so the head dimension remains aligned with the provided data.
  • Introduces an auxiliary max reference for normalize_softmax_lse so chunked accumulations rescale into a common frame, keeping the log-sum-exp finite when combining scores (see the reference sketch after this list).
  • Clarifies that the kernels always expect valid saux/dsaux pointers, so callers that skip the feature wire in sentinel tensors, preventing undefined memory access when the sink logits are unused.
  • Reads the per-head auxiliary scale from params and feeds it into the softmax normalization so the LSE uses the precomputed adjustment.
  • Accounts for the dropped sink term by computing its shared ds contribution per block and atomically accumulating it, aligning backward gradients with the forward sink probability handling.
  • Clarifies the sink softmax comments and simplifies the backward description.
  • Ensures the optional auxiliary logits path flows through the contiguous inputs and gradient returns in a consistent argument order.
  • Converts non-float32 sink logits to float32 instead of rejecting them, keeping attention paths usable.
  • Documents the sink tensor in both APIs and makes it optional in the varlen signature for clarity.
  • Drops misleading comments implying the auxiliary sink buffers are optional so the documentation matches the kernels' actual expectations.
  • Adds the precomputed auxiliary scale from params to the softmax normalization step so attention rows reuse the per-head auxiliary state instead of recomputing it.
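
For intuition, here is a minimal PyTorch reference of the sink-augmented normalization (an illustrative sketch, not the Triton/CUDA kernel code; the shapes are assumptions): the sink logit joins the max and log-sum-exp, its probability is dropped from the output, and every real attention weight shrinks accordingly.

```python
import torch

def sink_softmax_reference(scores: torch.Tensor, s_aux: torch.Tensor):
    """scores: (nheads, seqlen_q, seqlen_k) already-scaled attention scores.
    s_aux:  (nheads,) per-head sink auxiliary logit.
    Returns (probs, lse); each row of probs sums to less than one because the
    sink probability is computed but dropped from the output."""
    s_aux = s_aux.view(-1, 1, 1)                     # broadcast over query rows
    row_max = scores.amax(dim=-1, keepdim=True)      # per-row max of real scores
    max_ext = torch.maximum(row_max, s_aux)          # extended max reference frame
    exp_scores = torch.exp(scores - max_ext)         # real keys, rescaled
    exp_sink = torch.exp(s_aux - max_ext)            # sink term, rescaled
    denom = exp_scores.sum(dim=-1, keepdim=True) + exp_sink
    lse = max_ext + torch.log(denom)                 # log-sum-exp includes the sink
    probs = exp_scores / denom                       # sink probability is dropped here
    return probs, lse.squeeze(-1)
```

With s_aux filled with -inf (the sentinel the API installs when the feature is unused), exp_sink is exactly zero and the expression reduces to the ordinary softmax, which is why the kernels can always expect a valid pointer.
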
Copilot AI (Contributor) left a comment


Pull request overview

This PR enhances numerical stability in sparse attention mechanisms by introducing support for sink auxiliary logits (s_aux). The sink auxiliary logits are added as an extra "sink" position in the softmax computation, where the sink probability is computed but dropped from the output, allowing other attention weights to be reduced for improved numerical stability.
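
In formulas (a restatement of the mechanism above, not text taken from the diff), with real scores $s_{ij}$ for query row $i$ and the per-head sink logit $s_{\text{aux}}$:

$$
m_i = \max\!\Big(\max_j s_{ij},\ s_{\text{aux}}\Big), \qquad
\mathrm{lse}_i = m_i + \log\!\Big(\sum_j e^{\,s_{ij}-m_i} + e^{\,s_{\text{aux}}-m_i}\Big), \qquad
P_{ij} = e^{\,s_{ij}-\mathrm{lse}_i}.
$$

The sink carries probability $e^{\,s_{\text{aux}}-\mathrm{lse}_i}$ that never multiplies a value vector, so the real weights $P_{ij}$ sum to less than one.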

Key Changes

  • Added an s_aux parameter (shape: (nheads,)) to the forward and backward attention kernels across the Triton and CUDA implementations.
  • Modified the softmax normalization to rescale into an extended max reference frame that includes the sink auxiliary logit.
  • Updated the backward pass to compute gradients for the sink auxiliary logits using atomic additions (a short derivation of the reduced quantity follows this list).
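
For the backward pass, write $D_i = dO_i \cdot O_i$ (the per-row quantity the kernels already precompute as Delta/dP_sum). Since the sink position contributes no value vector, its upstream gradient is zero, and the standard softmax backward gives the reduced quantity (a short derivation, not text from the PR):

$$
\frac{\partial \mathcal{L}}{\partial s_{\text{aux}}}
= \sum_i P_{i,\text{sink}}\,(0 - D_i)
= -\sum_i e^{\,s_{\text{aux}}-\mathrm{lse}_i}\, D_i,
$$

which matches the per-block partial that the CUDA backward kernel reduces across the warp and atomically accumulates into dsaux_ptr (see the snippet quoted below).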

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.

Summary per file

  • flash_sparse_attn/flash_sparse_attn_triton.py: Added the HAS_SAUX flag, sink auxiliary logit handling in the forward kernel, a new _bwd_saux_kernel for gradient computation, and updated function signatures.
  • flash_sparse_attn/flash_sparse_attn_interface.py: Updated forward/backward function signatures to accept and return s_aux/ds_aux; added dtype validation and conversion for s_aux.
  • csrc/flash_sparse_attn/src/softmax.h: Modified normalize_softmax_lse to accept an s_aux parameter and rescale the softmax computation into the extended max reference frame.
  • csrc/flash_sparse_attn/src/flash_fwd_kernel.h: Updated the forward kernel to load s_aux per head and pass it to the softmax normalization.
  • csrc/flash_sparse_attn/src/flash_bwd_kernel.h: Added sink gradient computation using atomic adds to dsaux_ptr.
  • csrc/flash_sparse_attn/src/flash.h: Added saux_ptr and dsaux_ptr fields to the parameter structs.
  • csrc/flash_sparse_attn/flash_api.cpp: Added s_aux parameter handling, validation, and default tensor creation; disabled the seqlenq-ngroups swap when s_aux is provided; updated return values to include ds_aux.


@@ -408,6 +424,34 @@ def _bwd_preprocess_do_o_dot(
tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)

Copilot AI Dec 27, 2025


The tl.store operation for Delta should include a mask to avoid writing out-of-bounds values when offs_m >= seqlen_q. While the loaded values are masked with other=0.0, the store operation should also be masked to prevent potential memory corruption in the rounded buffer region. Add mask=offs_m < seqlen_q to the store operation.

Suggested change
tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta, mask=offs_m < seqlen_q)

Comment on lines +736 to +737
const float s_aux = reinterpret_cast<const float *>(params.saux_ptr)[bidh];
Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, s_aux);

Copilot AI Dec 27, 2025


The forward kernel unconditionally dereferences params.saux_ptr without checking if it's null. According to the API code in flash_api.cpp, when saux_ is not provided, a tensor filled with -inf is created. However, the kernel should either check for null or the comment in flash.h should be updated to clarify that this pointer must always be valid (never null). The current comment "filled with -inf when the feature is not used" suggests it's always allocated, which should be made more explicit.

Comment on lines +686 to +697
const float s_aux = reinterpret_cast<const float *>(params.saux_ptr)[bidh];
float ds_aux_partial = 0.f;
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
// If lse == inf (OOB row), exp(s_aux - inf) = 0 so this contributes 0.
ds_aux_partial += -expf(s_aux - float(lse(mi))) * float(dP_sum(mi));
}
FLASH_NAMESPACE::SumOp<float> sum_op;
ds_aux_partial = FLASH_NAMESPACE::Allreduce<32>::run(ds_aux_partial, sum_op);
if ((tidx & 31) == 0) {
atomicAdd(&reinterpret_cast<float *>(params.dsaux_ptr)[bidh], ds_aux_partial);
}

Copilot AI Dec 27, 2025


Similar to the forward kernel, the backward kernel unconditionally dereferences params.saux_ptr and params.dsaux_ptr without null checks. The code relies on these pointers always being valid (allocated to -inf tensors when the feature is not used). This should be documented more clearly or the code should add explicit null checks for safety.

dbias_expanded = dbias_expanded.sum(dim=0, keepdim=True)
dbias.copy_(dbias_expanded)
return dq, dk, dv, dbias if has_bias else None
return dq, dk, dv, dbias if has_bias else None, (ds_aux.to(s_aux.dtype) if has_saux else None)

Copilot AI Dec 27, 2025


In the backward function, when has_saux is True but s_aux is not a tensor (which could happen if it's stored as an empty tensor), calling s_aux.dtype will fail. The code should check if s_aux.numel() > 0 before calling .dtype, similar to how it's done elsewhere in the codebase. Consider changing this to: ds_aux.to(s_aux.dtype) if (has_saux and s_aux.numel() > 0) else None.

Suggested change
return dq, dk, dv, dbias if has_bias else None, (ds_aux.to(s_aux.dtype) if has_saux else None)
return dq, dk, dv, dbias if has_bias else None, (
ds_aux.to(s_aux.dtype) if (has_saux and s_aux.numel() > 0) else None
)

for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float max_scores = row_max(mi);
float max_ext = max_scores > s_aux ? max_scores : s_aux;

Copilot AI Dec 27, 2025


The condition max_scores > s_aux ? max_scores : s_aux could produce incorrect results when either value is NaN, since NaN comparisons always return false. When dealing with numerical stability in softmax computations, it's safer to use fmaxf(max_scores, s_aux) which handles NaN values according to IEEE 754 semantics (returns the non-NaN value if one operand is NaN).

Suggested change
float max_ext = max_scores > s_aux ? max_scores : s_aux;
float max_ext = fmaxf(max_scores, s_aux);

Comment on lines +493 to +499
at::Tensor saux;
if (saux_.has_value()) {
saux = saux_.value();
TORCH_CHECK(saux.dtype() == torch::kFloat32, "s_aux must have dtype float32");
CHECK_DEVICE(saux);
TORCH_CHECK(saux.dim() == 1, "s_aux must be 1D with shape (num_heads,)");
TORCH_CHECK(saux.size(0) == num_heads, "s_aux must have shape (num_heads,)");

Copilot AI Dec 27, 2025


The validation check for s_aux.size(0) == num_heads occurs after num_heads may have been modified by the seqlenq_ngroups_swapped logic (line 490). However, this is actually correct because when has_saux is true, seqlenq_ngroups_swapped is forced to false (line 460), so num_heads is never modified. Still, for code clarity and to prevent future bugs, consider validating the s_aux shape against the original num_heads value before the swap logic, or add a comment explaining why this order is safe.

@LoserCheems merged commit a5c2a65 into main on Dec 27, 2025
7 checks passed