Improves softmax stability with log2 scaling#228

Merged
LoserCheems merged 5 commits into main from optim-triton-version
Mar 1, 2026
Conversation

@LoserCheems
Collaborator

Summary

  • This update enhances the numerical stability of the softmax function.

Root Cause

  • The previous implementation computed softmax exponentials directly in the natural base, which is prone to overflow/underflow for extreme logits.

Changes

  • Refactored the softmax function to utilize log2 scaling, improving the computation of exponentials and the log-sum-exp operation.
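The refactor rests on the identity exp(x) = 2^(x · log2 e): the softmax scale can be folded into a single log2-domain multiplier so the kernel only ever calls the fast exp2 primitive. A minimal NumPy sketch of this scaled, max-subtracted softmax (the function name and shapes are illustrative, not the kernel's API):

```python
import numpy as np

LOG2_E = 1.4426950408889634  # log2(e)

def softmax_exp2(scores, softmax_scale):
    # Fold the softmax scale into one log2-domain multiplier:
    # exp(softmax_scale * x) == 2 ** (x * softmax_scale * log2(e))
    scale_log2 = softmax_scale * LOG2_E
    x = scores * scale_log2
    x = x - x.max(axis=-1, keepdims=True)  # max subtraction keeps exp2 in range
    p = np.exp2(x)                         # maps to the fast ex2 unit on GPUs
    return p / p.sum(axis=-1, keepdims=True)
```

Because the max is subtracted in the log2 domain, the largest exponent is always 0, so exp2 never overflows even for extreme inputs.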

Reproduction

  • The issue can be reproduced by testing the softmax function with extreme input values.

Tests

  • Added tests to validate the stability of the new softmax implementation under various conditions.

Compatibility

  • No backward compatibility issues are expected.

Checklist

Uses log2-based scaling to align exponentials and log-sum-exp, improving numerical stability.
Improves reuse by extracting the inner loop and aligning softmax scaling with log2 math to simplify masking paths and reduce duplicated logic
Copilot AI review requested due to automatic review settings March 1, 2026 03:59
Contributor

Copilot AI left a comment


Pull request overview

This PR refactors the Triton flash-attention forward softmax path to use base-2 exponentiation (exp2) with a log2(e) scaling factor, aiming to improve numerical stability for extreme logits while simplifying inner-loop code reuse.

Changes:

  • Introduces a shared _fwd_inner_base_kernel() helper in the Triton forward kernel and switches kernel inputs from softmax_scale to softmax_scale_log2.
  • Updates activations.online_softmax() to compute probabilities/rescaling via tl.math.exp2(... * scale_log2).
  • Updates activations.finalize() to accept scale_log2 and adjust LSE computation accordingly.
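As a rough model of the tl.math.exp2(... * scale_log2) update, one block step of the online softmax can be sketched in NumPy. This is a stand-in for the Triton code, not the kernel itself; the names mirror online_softmax() and scale_log2 is assumed to be softmax_scale * log2(e):

```python
import numpy as np

def online_softmax_step(acc_s, row_max, row_sum, scale_log2):
    # One [BLOCK_M, BLOCK_N] tile of exp2-based online softmax: both the
    # unnormalized probabilities and the running-sum rescale use
    # 2 ** ((x - new_max) * scale_log2).
    new_max = np.maximum(row_max, acc_s.max(axis=-1))
    row_scale = np.exp2((row_max - new_max) * scale_log2)  # rescale old accumulators
    p = np.exp2((acc_s - new_max[:, None]) * scale_log2)   # unnormalized probs
    row_sum = row_sum * row_scale + p.sum(axis=-1)
    return p, new_max, row_sum, row_scale
```

In the actual kernel, row_scale would also rescale the running output accumulator before the p @ v contribution of the new block is added.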

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
flash_sparse_attn/ops/triton/flash_fwd.py Refactors the forward kernel loop body into a helper and passes softmax_scale_log2 into the online softmax/finalize path.
flash_sparse_attn/ops/triton/activations.py Switches online softmax to exp2 with a new scale_log2 parameter and modifies finalize/LSE computation.
Comments suppressed due to low confidence (1)

flash_sparse_attn/ops/triton/activations.py:25

  • online_softmax() now requires a scale_log2 argument, but its docstring still documents only CHECK_INF and doesn't explain the expected units (e.g., softmax_scale * log2(e)) or how it relates to the returned row_max/row_sum. Please update the docstring to include scale_log2 so callers can pass the correct value.
def online_softmax(
    acc_s,
    row_max,
    row_sum,
    scale_log2,
    CHECK_INF: tl.constexpr,
):
    """
    Apply online softmax to acc_s, and update row_max and row_sum.

    :param acc_s: Attention scores tensor of shape [BLOCK_M, BLOCK_N].
    :param row_max: Current maximum values per row of shape [BLOCK_M], init to -inf.
    :param row_sum: Current sum values per row of shape [BLOCK_M], init to 0.
    :param CHECK_INF: Boolean flag indicating if -inf row_max should be clamped to 0.



Comment on lines +71 to +77
ln2 = tl.math.log2(2.0)
scale = scale_log2 * ln2
lse = tl.where(
    acc_o_is_zero_or_nan,
    float("-inf"),
    row_max * scale + tl.log(row_sum),
)

Copilot AI Mar 1, 2026


finalize() computes ln2 = tl.math.log2(2.0) (which is 1), so scale = scale_log2 * ln2 leaves scale == scale_log2. Since row_sum is later logged with tl.log (natural log), the row_max term must be multiplied by the natural-exp scale (softmax_scale), i.e. scale_log2 * ln(2), not scale_log2. As-is, lse will be systematically incorrect and will break downstream logic that expects natural-log LSE (e.g., split-KV combine).
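A NumPy sketch of the correction this comment describes, assuming row_max is tracked in raw-logit units and row_sum was accumulated via exp2((x - row_max) * scale_log2); the function name is hypothetical:

```python
import math
import numpy as np

LN2 = math.log(2.0)  # ln(2) ≈ 0.6931, whereas tl.math.log2(2.0) == 1

def finalize_lse(row_max, row_sum, scale_log2):
    # Natural-log LSE: row_sum is logged with the natural log, so the
    # row_max term must carry the natural-exp scale,
    # i.e. scale_log2 * ln(2) == softmax_scale, not scale_log2 itself.
    return row_max * (scale_log2 * LN2) + np.log(row_sum)
```

With ln(2) in place, the result matches a reference natural-log logsumexp of softmax_scale * scores, which is what downstream split-KV combine expects.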

Simplifies exp2 usage and adjusts log scaling to improve numerical correctness in output normalization and log-sum-exp computation
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

Comments suppressed due to low confidence (2)

flash_sparse_attn/ops/triton/activations.py:26

  • The online_softmax docstring/param docs don’t mention the newly added scale_log2 parameter, which makes the API change harder to understand for future maintainers. Please document what scale_log2 represents (e.g., softmax_scale * log2(e)) and its expected dtype/range.
    """
    Apply online softmax to acc_s, and update row_max and row_sum.

    :param acc_s: Attention scores tensor of shape [BLOCK_M, BLOCK_N].
    :param row_max: Current maximum values per row of shape [BLOCK_M], init to -inf.
    :param row_sum: Current sum values per row of shape [BLOCK_M], init to 0.
    :param CHECK_INF: Boolean flag indicating if -inf row_max should be clamped to 0.

flash_sparse_attn/ops/triton/activations.py:63

  • The finalize docstring doesn’t describe the new scale_log2 parameter, even though it’s now required. Please update the docstring to include what scale_log2 is and why it’s needed for computing lse.
    """
    Finalize online softmax by computing output scale and logsumexp.

    :param row_max: Final maximum values per row of shape [BLOCK_M].
    :param row_sum: Final sum values per row of shape [BLOCK_M].
    :param final_scale: Scaling factor to be applied to the output.



    acc_s,
    row_max,
    row_sum,
    scale_log2,

Copilot AI Mar 1, 2026


activations.online_softmax now requires a new scale_log2 argument, but there are existing call sites (e.g. in flash_sparse_fwd.py) that still call it without this parameter. This will break compilation/runtime for the Triton sparse forward path unless those call sites are updated or scale_log2 is given a backward-compatible default.

Suggested change:
-     scale_log2,
+     scale_log2=1.0,

Comment on lines 51 to 56
def finalize(
    row_max,
    row_sum,
    scale_log2,
    final_scale,
):

Copilot AI Mar 1, 2026


activations.finalize now requires scale_log2, but there are existing call sites (e.g. in flash_sparse_fwd.py) that still call finalize(row_max, row_sum, final_scale=...) without the new argument. Please update those call sites or keep finalize backward compatible, otherwise the sparse Triton forward will fail.

# Apply mask
acc_s = mask.apply_mask(
    acc_s=acc_s,
for n_block in tl.range(n_block_max - 1, n_block_min_causal_local - 1, -1):

Copilot AI Mar 1, 2026


tl.range(...) is used for the n_block loops here, but the rest of the Triton kernels in this repo use Python range(...) / tl.static_range(...). triton.language does not consistently provide a range helper across Triton versions; if tl.range isn't available in the project's supported Triton/PyTorch-Triton versions, this will fail at compile time. Consider reverting to range(...) (as used in flash_sparse_fwd.py) or using a tl.static_range + bounds check pattern.

Suggested change:
- for n_block in tl.range(n_block_max - 1, n_block_min_causal_local - 1, -1):
+ for n_block in range(n_block_max - 1, n_block_min_causal_local - 1, -1):

# Apply online softmax
p, row_max, row_sum, row_scale = activations.online_softmax(
    acc_s=acc_s,
for n_block in tl.range(n_block_max_no_mask - 1, n_block_min - 1, -1):

Copilot AI Mar 1, 2026


Same concern as above: this second for n_block in tl.range(...) loop may not compile on supported Triton versions if tl.range is unavailable. Using the same loop construct as the other kernels (Python range or tl.static_range with masking) would avoid a potential build break.

Suggested change:
- for n_block in tl.range(n_block_max_no_mask - 1, n_block_min - 1, -1):
+ for n_block in range(n_block_max_no_mask - 1, n_block_min - 1, -1):

@LoserCheems LoserCheems merged commit 39deca0 into main Mar 1, 2026
4 checks passed