Refactor masking logic in backward kernel functions #255
Conversation
Pull request overview
Refactors masking placement in Triton backward inner kernels so masking is applied immediately after score computation rather than later in the softmax/gradient path. This targets clearer control flow and aligns masking with downstream max/threshold computations used for block skipping.
Changes:
- Move `mask.apply_mask(...)` to directly follow the `acc_s` computation in the sparse backward inner kernel.
- Move `mask.apply_mask(...)` to directly follow the `acc_s` computation in the gated backward inner kernel.
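The relocation can be pictured with a simplified NumPy sketch of the inner-kernel flow. This is illustrative only, not the actual Triton code; the function names and the toy `mask_fn` are placeholders:

```python
import numpy as np

def inner_kernel_before(q, k, mask_fn):
    """Old ordering: block-max logic sees the raw, unmasked scores."""
    acc_s = q @ k.T              # score computation
    block_max = acc_s.max()      # block-max / skip logic on raw scores
    acc_s = mask_fn(acc_s)       # mask applied later, in the softmax path
    return acc_s, block_max

def inner_kernel_after(q, k, mask_fn):
    """New ordering: mask is applied immediately after score computation."""
    acc_s = q @ k.T              # score computation
    acc_s = mask_fn(acc_s)       # mask applied right after acc_s
    block_max = acc_s.max()      # block-max / skip logic on masked scores
    return acc_s, block_max

# Toy data: the second key position is masked out.
q = np.array([[1.0, 1.0]])
k = np.array([[2.0, 0.0], [0.0, 3.0]])
keep = np.array([[True, False]])
mask_fn = lambda s: np.where(keep, s, float("-inf"))

s_old, max_old = inner_kernel_before(q, k, mask_fn)
s_new, max_new = inner_kernel_after(q, k, mask_fn)
# The final masked scores agree, but the block max differs:
# max_old includes the masked-out score; max_new does not.
```

The final `acc_s` is identical either way; what moves is which scores the block-max sees, which is exactly the reviewers' concern below.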
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `flash_sparse_attn/ops/triton/flash_sparse_bwd.py` | Applies the attention mask earlier in `_bwd_inner_sparse_base_kernel`, before block-max/skip logic. |
| `flash_sparse_attn/ops/triton/flash_gated_bwd.py` | Applies the attention mask earlier in `_bwd_inner_gated_base_kernel`, before block-max/skip logic. |
```python
if IS_MASK:
    # Apply mask
    acc_s = mask.apply_mask(
        acc_s=acc_s,
        m_block=m_block,
```
Moving mask.apply_mask ahead of block_max_curr = tl.max(acc_s) changes the semantics of the block-level skip_softmax decision and block_max updates (masked-out scores can no longer influence the running max). This can change which blocks are skipped and potentially the resulting gradients, so the PR description (“No bugs were fixed…”) may be inaccurate; please confirm this behavioral change is intended and update the description / add a correctness note or test coverage accordingly.
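The behavioral change the reviewer describes can be made concrete with a small NumPy sketch. This is a hypothetical model of the skip decision, not the kernel itself; `skip_threshold` and `mask_before_max` are illustrative names:

```python
import numpy as np

NEG_INF = float("-inf")

def block_max_and_skip(acc_s, keep, skip_threshold, mask_before_max):
    """Toy model of the block-skip decision: returns (block_max, skip_softmax)
    depending on whether masking happens before or after the block max."""
    scores = acc_s.copy()
    if mask_before_max:
        scores = np.where(keep, scores, NEG_INF)  # new order: mask first
    block_max = scores.max()
    skip_softmax = block_max < skip_threshold     # hypothetical skip rule
    if not mask_before_max:
        scores = np.where(keep, scores, NEG_INF)  # old order: mask later
    return block_max, skip_softmax

# A block whose only large score falls on a masked-out position:
acc_s = np.array([[0.1, 9.0], [0.2, 0.3]])
keep = np.array([[True, False], [True, True]])

old = block_max_and_skip(acc_s, keep, 1.0, mask_before_max=False)
new = block_max_and_skip(acc_s, keep, 1.0, mask_before_max=True)
# Old order: block_max = 9.0, so the block is NOT skipped.
# New order: block_max = 0.3, so the block IS skipped.
```

Under this model the same block is processed in one ordering and skipped in the other, which is why the reordering can change which blocks run the softmax path.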
```python
if IS_MASK:
    # Apply mask
    acc_s = mask.apply_mask(
        acc_s=acc_s,
        m_block=m_block,
        n_block=n_block,
```
As in the sparse kernel, applying the mask before computing block_max_curr changes the skip_softmax gating behavior and running block_max for gated backward. This can affect which blocks run the softmax path and therefore can change outputs; please confirm this is intended and reflect it in the PR description / add a targeted correctness check.
This refactor improves the clarity of the masking logic in the backward kernel functions by applying the mask immediately after score computation, consolidating the mask application to reduce redundancy and improve maintainability. No bugs were fixed; the change is intended to clean up the code structure.