
[FEATURE SUPPORT] Add Triton backward support#235

Merged
LoserCheems merged 12 commits into main from optim-triton-version
Mar 9, 2026
Conversation


@LoserCheems (Collaborator) commented Mar 8, 2026

Summary

This PR introduces end-to-end backward support for the Triton flash attention path, including:

  • Backward launch configuration selection by GPU architecture.
  • Backward grid helpers for main, preprocess, and postprocess kernels.
  • A backward preprocess kernel to compute dPsum, convert LSE to log2 space, and initialize the dQ accumulation buffer.
  • A backward core kernel to compute dQ, dK, and dV with support for causal/local masking, varlen inputs, and GQA accumulation behavior.
  • A backward postprocess kernel to scale and cast accumulated dQ to output dtype.

The goal is to make backward computation available in the same Triton stack as forward, with architecture-aware launch behavior and varlen-compatible memory layout.

Design

The implementation follows the same staged design as the reference cute pipeline:

  • Stage 1 (preprocess): prepare numerically stable intermediate tensors and reset dQ accumulation.
  • Stage 2 (main backward): iterate over N blocks, compute attention-gradient math, atomically accumulate dQ tiles, and produce dK/dV accumulators.
  • Stage 3 (postprocess): apply scale and cast dQ accumulation into final gradient tensor.
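For illustration, the grid sizing behind the three stages can be sketched in plain Python. The block sizes and the flattened batch*heads axis below are assumptions for the sketch, not the repo's actual grid helpers:

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, the usual Triton grid-sizing primitive.
    return (a + b - 1) // b

def bwd_grids(batch: int, nheads: int, max_seqlen_q: int, max_seqlen_k: int,
              block_m: int = 64, block_n: int = 64):
    # Hypothetical layout: preprocess and postprocess tile over Q rows,
    # while the main kernel tiles over K/V columns so each program owns a
    # dK/dV tile and accumulates dQ atomically across N blocks.
    pre = (cdiv(max_seqlen_q, block_m), batch * nheads)
    main = (cdiv(max_seqlen_k, block_n), batch * nheads)
    post = (cdiv(max_seqlen_q, block_m), batch * nheads)
    return pre, main, post
```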

Key design choices:

  • Keep backward orchestration modular across three files for easier tuning and debugging.
  • Reuse shared seqlen/padded-offset helpers for fixed-length and varlen consistency.
  • Use architecture-based launch templates to avoid hardcoding one-size-fits-all kernel configs.
  • Use float32 accumulators for numerical stability in intermediate reductions.
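A minimal sketch of what architecture-based config selection looks like; the specific block sizes, warp counts, and stage counts below are illustrative placeholders, not the tuned values from the launch templates:

```python
def select_bwd_config(capability):
    # Map a CUDA compute capability (major, minor) to a backward launch
    # config. All numbers here are placeholder stand-ins for the tuned
    # per-architecture templates.
    major, _minor = capability
    if major >= 10:  # Blackwell-class (e.g. B200)
        return {"BLOCK_M": 128, "BLOCK_N": 128, "num_warps": 8, "num_stages": 3}
    if major == 9:   # Hopper-class (e.g. H100)
        return {"BLOCK_M": 128, "BLOCK_N": 64, "num_warps": 8, "num_stages": 3}
    # Ampere-class (e.g. A100) and older: smaller tiles, fewer warps.
    return {"BLOCK_M": 64, "BLOCK_N": 64, "num_warps": 4, "num_stages": 2}
```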

Alternatives considered:

  • A monolithic backward kernel without preprocess/postprocess splitting was rejected: it would be harder to maintain and harder to control numerically and to debug.
  • A device-agnostic fixed launch config was rejected because a single configuration risks poor occupancy and performance across Ampere-, Hopper-, and Blackwell-class GPUs.

Changes

New/updated functionality includes:

  • Added backward launch config API in launch templates.
  • Added backward grid builders for main, preprocess, and postprocess kernels.
  • Added Triton backward preprocess implementation.
  • Added Triton backward postprocess implementation.
  • Added Triton backward main kernel and Python entrypoints for:
  1. Fixed-length backward.
  2. Varlen backward with cu_seqlens and optional seqused.

Public behavior:

  • Backward path now exists in Triton backend and returns dq, dk, dv outputs for both fixed and varlen use cases.
  • Varlen API expects max sequence lengths for launch sizing.
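The cumulative-offset and max-length inputs that the varlen launch sizing expects can be derived as follows (the helper name is illustrative, not part of the API):

```python
def cu_seqlens_and_max(seqlens):
    # Build cumulative sequence offsets (cu_seqlens-style) and the max
    # sequence length that the varlen entrypoint needs for grid sizing.
    cu = [0]
    for s in seqlens:
        cu.append(cu[-1] + s)
    return cu, max(seqlens)
```

A fallback like `max(seqlens)` is essentially the automatic inference suggested as follow-up hardening in the implementation notes below.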

Implementation notes

  • dQ is accumulated in a float32 buffer and finalized in the postprocess stage to improve numerical stability.
  • dK is scaled by softmax_scale before final write, consistent with backward derivation.
  • GQA path uses atomic accumulation for dK/dV when multiple Q heads map to one KV head.
  • Current varlen path depends on provided max_seqlen_q/max_seqlen_k for grid sizing.
  • Follow-up hardening recommended: add automatic fallback inference of max sequence lengths when they are not explicitly provided.
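The preprocess/postprocess semantics described above can be written as a small NumPy reference; shapes and function names are illustrative, and plain float32 arrays stand in for the on-device accumulators:

```python
import numpy as np

def bwd_preprocess_ref(o, do, lse):
    # dPsum: per-row dot product of O and dO, computed in float32.
    dpsum = (o.astype(np.float32) * do.astype(np.float32)).sum(axis=-1)
    # LSE converted from natural-log to log2 space for exp2-based kernels.
    lse_log2 = lse.astype(np.float32) / np.float32(np.log(2.0))
    # dQ accumulator starts zeroed so atomic adds can run across N blocks.
    dq_acc = np.zeros(o.shape, dtype=np.float32)
    return dpsum, lse_log2, dq_acc

def bwd_postprocess_ref(dq_acc, softmax_scale, out_dtype=np.float16):
    # Apply the softmax scale once at the end, then cast to output dtype.
    return (dq_acc * np.float32(softmax_scale)).astype(out_dtype)
```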

Tests

Validation completed:

  • Reference parity review against cute backward staging and dataflow:
  1. preprocess semantic parity (dPsum/LSELog2/dQ init).
  2. core backward math flow parity (p, ds, dq/dk/dv accumulation).
  3. postprocess scaling/cast parity for dQ finalization.
  • Static consistency checks:
  1. fixed-length stride and shape mapping.
  2. varlen padded offset usage for intermediate buffers.
  3. mask/no-mask block range splitting.

Pending runtime validation:

  • Multi-arch performance sanity checks (A100/H100/B200-class paths).

Documentation

  • Inline code comments added in kernels for major computation stages.
  • Recommended follow-up: add a short backward architecture section to developer docs describing:
  1. preprocess/main/postprocess pipeline.
  2. varlen max sequence length requirements.
  3. architecture launch config rationale.

Copilot AI review requested due to automatic review settings March 8, 2026 15:11

Copilot AI left a comment


Pull request overview

This PR aims to address correctness issues in Triton attention tiling/masking by introducing new helper utilities for block boundary computations and pointer construction, and by adjusting masking index computation for the SWAP_AB (swapped Q/K) case.

Changes:

  • Added a generic Triton JIT pointer-construction helper (make_ptrs) in seqlen_info.py.
  • Added new Triton JIT helpers in block_info.py for m-block boundary computations under causal/local constraints.
  • Fixed index assignments in mask.apply_mask for the SWAP_AB branch; adjusted comments in flash_fwd.py.
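As a rough illustration of the m-block boundary math such helpers cover (the signature and the bottom-right alignment convention are assumptions for the sketch, not the repo's exact code):

```python
def m_block_min_causal(n_block, block_m, block_n, seqlen_q, seqlen_k):
    # Bottom-right-aligned causal mask: a query row q_idx may attend to a
    # key k_idx when q_idx + (seqlen_k - seqlen_q) >= k_idx.
    k_start = n_block * block_n             # first key index in this N block
    q_min = max(k_start - (seqlen_k - seqlen_q), 0)
    return q_min // block_m                 # first M block with any unmasked row
```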

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File descriptions:

  • flash_sparse_attn/ops/triton/seqlen_info.py: adds the make_ptrs helper for pointer creation (currently no call sites in the repo).
  • flash_sparse_attn/ops/triton/mask.py: adjusts q_idx/k_idx construction for SWAP_AB masking.
  • flash_sparse_attn/ops/triton/flash_fwd.py: comment-only refactor around pointer-creation sections.
  • flash_sparse_attn/ops/triton/block_info.py: adds new m-block boundary helpers for causal/local scheduling (currently no call sites in the repo).
Comments suppressed due to low confidence (1)

flash_sparse_attn/ops/triton/flash_fwd.py:265

  • The PR description says new tests were added and existing tests validated, but this diff doesn’t include any test changes/additions. Please either add the corresponding tests in this PR or update the description/checklist to reflect what was actually changed.
    # Create pointers
    if not PACK_GQA:
        lse_ptrs = tl.make_block_ptr(
            base=lse_base,
            shape=(actual_seqlen_q,),


@LoserCheems changed the title from "[BUG FIX] Improve block calculations and tensor operations" to "[FEATURE SUPPORT] Add Triton backward support" Mar 9, 2026
@LoserCheems requested a review from Copilot March 9, 2026 08:01
Copy link
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 4 comments.



@LoserCheems merged commit 3491059 into main Mar 9, 2026
