Improves softmax stability with log2 scaling #228
Uses log2-based scaling to align exponentials and log-sum-exp, improving numerical stability.
Improves reuse by extracting the inner loop and aligning softmax scaling with log2 math, simplifying masking paths and reducing duplicated logic.
Pull request overview
This PR refactors the Triton flash-attention forward softmax path to use base-2 exponentiation (exp2) with a log2(e) scaling factor, aiming to improve numerical stability for extreme logits while simplifying inner-loop code reuse.
Changes:
- Introduces a shared `_fwd_inner_base_kernel()` helper in the Triton forward kernel and switches kernel inputs from `softmax_scale` to `softmax_scale_log2`.
- Updates `activations.online_softmax()` to compute probabilities/rescaling via `tl.math.exp2(... * scale_log2)`.
- Updates `activations.finalize()` to accept `scale_log2` and adjust LSE computation accordingly.
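The changes above hinge on the identity `exp(x * s) == 2 ** (x * s * log2(e))`: premultiplying the softmax scale by `log2(e)` once on the host lets the kernel use the cheaper `exp2` in the inner loop. A minimal scalar sketch (names are illustrative, not the kernel's actual API):

```python
import math

# Illustrative check of the rescaling identity behind the PR:
# exp(x * softmax_scale) == 2 ** (x * softmax_scale * log2(e)),
# so a kernel can replace exp with exp2 after one host-side multiply.
LOG2E = math.log2(math.e)

def prob_natural(x, softmax_scale):
    # Natural-exponent path (what the kernel computed before)
    return math.exp(x * softmax_scale)

def prob_base2(x, softmax_scale_log2):
    # Base-2 path (what the kernel computes after the PR)
    return 2.0 ** (x * softmax_scale_log2)

softmax_scale = 1.0 / 8.0                    # e.g. 1/sqrt(head_dim) for head_dim = 64
softmax_scale_log2 = softmax_scale * LOG2E   # precomputed once on the host

print(prob_natural(3.0, softmax_scale))
print(prob_base2(3.0, softmax_scale_log2))   # matches up to float rounding
```

The two paths agree to within floating-point rounding; the win is that `exp2` maps directly onto the GPU's hardware instruction.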
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `flash_sparse_attn/ops/triton/flash_fwd.py` | Refactors the forward kernel loop body into a helper and passes `softmax_scale_log2` into the online softmax/finalize path. |
| `flash_sparse_attn/ops/triton/activations.py` | Switches online softmax to `exp2` with a new `scale_log2` parameter and modifies finalize/LSE computation. |
Comments suppressed due to low confidence (1)
flash_sparse_attn/ops/triton/activations.py:25
`online_softmax()` now requires a `scale_log2` argument, but its docstring still documents only `CHECK_INF` and doesn't explain the expected units (e.g., `softmax_scale * log2(e)`) or how it relates to the returned `row_max`/`row_sum`. Please update the docstring to include `scale_log2` so callers can pass the correct value.
```python
def online_softmax(
    acc_s,
    row_max,
    row_sum,
    scale_log2,
    CHECK_INF: tl.constexpr,
):
    """
    Apply online softmax to acc_s, and update row_max and row_sum.

    :param acc_s: Attention scores tensor of shape [BLOCK_M, BLOCK_N].
    :param row_max: Current maximum values per row of shape [BLOCK_M], init to -inf.
    :param row_sum: Current sum values per row of shape [BLOCK_M], init to 0.
    :param CHECK_INF: Boolean flag indicating if -inf row_max should be clamped to 0.
    """
```
```python
ln2 = tl.math.log2(2.0)
scale = scale_log2 * ln2
lse = tl.where(
    acc_o_is_zero_or_nan,
    float("-inf"),
    row_max * scale + tl.log(row_sum),
)
```
`finalize()` computes `ln2 = tl.math.log2(2.0)` (which is 1), so `scale = scale_log2 * ln2` leaves `scale == scale_log2`. Since `row_sum` is later logged with `tl.log` (natural log), the `row_max` term must be multiplied by the natural-exp scale (`softmax_scale`), i.e. `scale_log2 * ln(2)`, not `scale_log2`. As-is, `lse` will be systematically incorrect and will break downstream logic that expects natural-log LSE (e.g., split-KV combine).
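To make the point concrete, here is a hedged scalar check, assuming `row_max` tracks the *unscaled* score maximum and probabilities are computed as `exp2((s - row_max) * scale_log2)`, as in the standard flash-attention online softmax (values and names are illustrative):

```python
import math

softmax_scale = 0.125
scale_log2 = softmax_scale * math.log2(math.e)
scores = [1.0, 2.5, -0.5]  # hypothetical logits for one row

# Online-softmax state as the base-2 kernel path would track it
row_max = max(scores)
row_sum = sum(2.0 ** ((s - row_max) * scale_log2) for s in scores)

# Buggy: ln2 computed as log2(2.0) == 1, so the row_max term stays in base-2 units
lse_buggy = row_max * scale_log2 + math.log(row_sum)
# Fixed: multiply by ln(2) so the row_max term is in natural-log units
lse_fixed = row_max * scale_log2 * math.log(2.0) + math.log(row_sum)

# Reference natural-log LSE over the scaled logits
lse_ref = math.log(sum(math.exp(s * softmax_scale) for s in scores))

print(abs(lse_fixed - lse_ref) < 1e-12)  # True
print(abs(lse_buggy - lse_ref) < 1e-12)  # False
```

Note that `scale_log2 * ln(2)` is exactly `softmax_scale`, which is why the fixed form matches the reference.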
Simplifies exp2 usage and adjusts log scaling to improve numerical correctness in output normalization and log-sum-exp computation
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
Comments suppressed due to low confidence (2)
flash_sparse_attn/ops/triton/activations.py:26
- The `online_softmax` docstring/param docs don't mention the newly added `scale_log2` parameter, which makes the API change harder to understand for future maintainers. Please document what `scale_log2` represents (e.g., `softmax_scale * log2(e)`) and its expected dtype/range.
```python
"""
Apply online softmax to acc_s, and update row_max and row_sum.

:param acc_s: Attention scores tensor of shape [BLOCK_M, BLOCK_N].
:param row_max: Current maximum values per row of shape [BLOCK_M], init to -inf.
:param row_sum: Current sum values per row of shape [BLOCK_M], init to 0.
:param CHECK_INF: Boolean flag indicating if -inf row_max should be clamped to 0.
"""
```
flash_sparse_attn/ops/triton/activations.py:63
- The `finalize` docstring doesn't describe the new `scale_log2` parameter, even though it's now required. Please update the docstring to include what `scale_log2` is and why it's needed for computing `lse`.
```python
"""
Finalize online softmax by computing output scale and logsumexp.

:param row_max: Final maximum values per row of shape [BLOCK_M].
:param row_sum: Final sum values per row of shape [BLOCK_M].
:param final_scale: Scaling factor to be applied to the output.
"""
```
```python
acc_s,
row_max,
row_sum,
scale_log2,
```
`activations.online_softmax` now requires a new `scale_log2` argument, but there are existing call sites (e.g. in `flash_sparse_fwd.py`) that still call it without this parameter. This will break compilation/runtime for the Triton sparse forward path unless those call sites are updated or `scale_log2` is given a backward-compatible default.
```diff
-    scale_log2,
+    scale_log2=1.0,
```
```python
def finalize(
    row_max,
    row_sum,
    scale_log2,
    final_scale,
):
```
`activations.finalize` now requires `scale_log2`, but there are existing call sites (e.g. in `flash_sparse_fwd.py`) that still call `finalize(row_max, row_sum, final_scale=...)` without the new argument. Please update those call sites or keep `finalize` backward compatible, otherwise the sparse Triton forward will fail.
```python
# Apply mask
acc_s = mask.apply_mask(
    acc_s=acc_s,

for n_block in tl.range(n_block_max - 1, n_block_min_causal_local - 1, -1):
```
`tl.range(...)` is used for the `n_block` loops here, but the rest of the Triton kernels in this repo use Python `range(...)` / `tl.static_range(...)`. `triton.language` does not consistently provide a `range` helper across Triton versions; if `tl.range` isn't available in the project's supported Triton/PyTorch-Triton versions, this will fail at compile time. Consider reverting to `range(...)` (as used in `flash_sparse_fwd.py`) or using a `tl.static_range` + bounds check pattern.
```diff
-for n_block in tl.range(n_block_max - 1, n_block_min_causal_local - 1, -1):
+for n_block in range(n_block_max - 1, n_block_min_causal_local - 1, -1):
```
```python
# Apply online softmax
p, row_max, row_sum, row_scale = activations.online_softmax(
    acc_s=acc_s,

for n_block in tl.range(n_block_max_no_mask - 1, n_block_min - 1, -1):
```
Same concern as above: this second `for n_block in tl.range(...)` loop may not compile on supported Triton versions if `tl.range` is unavailable. Using the same loop construct as the other kernels (Python `range` or `tl.static_range` with masking) would avoid a potential build break.
```diff
-for n_block in tl.range(n_block_max_no_mask - 1, n_block_min - 1, -1):
+for n_block in range(n_block_max_no_mask - 1, n_block_min - 1, -1):
```