
Refactor masking logic in backward kernel functions #255

Merged
LoserCheems merged 1 commit into main from optim_triton_version
Mar 21, 2026

Conversation

@LoserCheems
Collaborator

This refactor improves the clarity and efficiency of the masking logic in the backward kernel functions. It consolidates mask application to reduce redundancy and improve maintainability. No bugs were fixed; the change is structural.

Copilot AI review requested due to automatic review settings March 20, 2026 17:53
Copilot AI (Contributor) left a comment


Pull request overview

Refactors masking placement in Triton backward inner kernels so masking is applied immediately after score computation rather than later in the softmax/gradient path. This targets clearer control flow and aligns masking with downstream max/threshold computations used for block skipping.

Changes:

  • Move mask.apply_mask(...) to directly follow acc_s computation in sparse backward inner kernel.
  • Move mask.apply_mask(...) to directly follow acc_s computation in gated backward inner kernel.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Reviewed files:

  • flash_sparse_attn/ops/triton/flash_sparse_bwd.py: applies the attention mask earlier in _bwd_inner_sparse_base_kernel, before the block-max/skip logic.
  • flash_sparse_attn/ops/triton/flash_gated_bwd.py: applies the attention mask earlier in _bwd_inner_gated_base_kernel, before the block-max/skip logic.
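To make the new control flow concrete, here is a minimal NumPy sketch of one backward inner-loop step with the post-refactor ordering (mask applied immediately after the score computation, before the block-max/skip decision). This is an illustrative model, not the Triton kernel: the function name `bwd_inner_step`, the `threshold` parameter, and the skip rule are hypothetical simplifications of the actual _bwd_inner_*_base_kernel logic.

```python
import numpy as np

def bwd_inner_step(q_block, k_block, mask_block, block_max_prev, threshold):
    """Simplified NumPy model of one backward inner-kernel step.

    Hypothetical sketch: names and the skip rule are simplifications
    of the Triton kernel, not its real signature.
    """
    # Score computation: acc_s = q @ k^T
    acc_s = q_block @ k_block.T
    # Post-refactor placement: apply the mask right after computing
    # scores. True in mask_block means "keep"; masked positions are
    # pushed to -inf so they cannot win the block max.
    acc_s = np.where(mask_block, acc_s, -np.inf)
    # Block-level skip decision now sees only masked scores.
    block_max_curr = acc_s.max()
    block_max = max(block_max_prev, block_max_curr)
    skip_softmax = block_max_curr < threshold
    return acc_s, block_max, skip_softmax
```

With this ordering, a fully masked block always produces a block max of -inf and is skipped, regardless of the raw score values.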


Comment on lines +55 to +59
if IS_MASK:
    # Apply mask
    acc_s = mask.apply_mask(
        acc_s=acc_s,
        m_block=m_block,

Copilot AI Mar 20, 2026


Moving mask.apply_mask ahead of block_max_curr = tl.max(acc_s) changes the semantics of the block-level skip_softmax decision and of the block_max updates: masked-out scores can no longer influence the running max. This can change which blocks are skipped and potentially the resulting gradients, so the PR description ("No bugs were fixed…") may be inaccurate. Please confirm this behavioral change is intended and update the description, or add a correctness note or test coverage accordingly.

Comment on lines +100 to +105
if IS_MASK:
    # Apply mask
    acc_s = mask.apply_mask(
        acc_s=acc_s,
        m_block=m_block,
        n_block=n_block,

Copilot AI Mar 20, 2026


As in the sparse kernel, applying the mask before computing block_max_curr changes the skip_softmax gating behavior and the running block_max for the gated backward kernel. This can affect which blocks run the softmax path and can therefore change outputs; please confirm this is intended and reflect it in the PR description, or add a targeted correctness check.

@LoserCheems LoserCheems merged commit 552bdb2 into main Mar 21, 2026
5 checks passed