[FEATURE SUPPORT] Add Triton backward support#235
Merged
LoserCheems merged 12 commits intomainfrom Mar 9, 2026
Merged
Conversation
…functions for improved block calculations in causal and local contexts
…re correct tensor operations
…_kernel for clarity
Contributor
There was a problem hiding this comment.
Pull request overview
This PR aims to address correctness issues in Triton attention tiling/masking by introducing new helper utilities for block boundary computations and pointer construction, and by adjusting masking index computation for the SWAP_AB (swapped Q/K) case.
Changes:
- Added a generic Triton JIT pointer-construction helper (
make_ptrs) inseqlen_info.py. - Added new Triton JIT helpers in
block_info.pyfor m-block boundary computations under causal/local constraints. - Fixed index assignments in
mask.apply_maskfor theSWAP_ABbranch; adjusted comments inflash_fwd.py.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
flash_sparse_attn/ops/triton/seqlen_info.py |
Adds make_ptrs helper for pointer creation (currently no call sites in repo). |
flash_sparse_attn/ops/triton/mask.py |
Adjusts q_idx/k_idx construction for SWAP_AB masking. |
flash_sparse_attn/ops/triton/flash_fwd.py |
Comment-only refactor around pointer creation sections. |
flash_sparse_attn/ops/triton/block_info.py |
Adds new m-block boundary helpers for causal/local scheduling (currently no call sites in repo). |
Comments suppressed due to low confidence (1)
flash_sparse_attn/ops/triton/flash_fwd.py:265
- The PR description says new tests were added and existing tests validated, but this diff doesn’t include any test changes/additions. Please either add the corresponding tests in this PR or update the description/checklist to reflect what was actually changed.
# Create pointers
if not PACK_GQA:
lse_ptrs = tl.make_block_ptr(
base=lse_base,
shape=(actual_seqlen_q,),
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
…uration based on device architecture
…in launch_grid.py
…bwd_postprocess.py
Contributor
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated 4 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
… launch_template.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR introduces end-to-end backward support for the Triton flash attention path, including:
The goal is to make backward computation available in the same Triton stack as forward, with architecture-aware launch behavior and varlen-compatible memory layout.
Design
The implementation follows the same staged design as the reference cute pipeline:
Key design choices:
Alternatives considered:
Changes
New/updated functionality includes:
Public behavior:
Implementation notes
Tests
Validation completed:
Pending runtime validation:
Documentation
Checklist