[PERFORMANCE OPTIMIZATION] Triton Sparse Base Forward Kernel with Gate-Based Sparsity #226

Merged

LoserCheems merged 8 commits into main from optime-triton-kernels on Feb 24, 2026

Conversation

@LoserCheems (Collaborator)

Summary

Implements a Triton-based sparse attention forward kernel (_fwd_sparse_base_kernel) that exploits gate-based sparsity to skip irrelevant attention blocks at runtime. The kernel introduces learnable per-head gating parameters (alpha, delta) with a logsigmoid activation and a dynamic threshold derived from gate_scale, enabling entire tiles to be skipped when the upper bound of their gate values falls below the threshold. This reduces effective FLOPs in proportion to the sparsity ratio without sacrificing model expressiveness.
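
As a rough host-side reference (not the kernel code; the function name and shapes here are illustrative), the gating mechanism can be sketched in NumPy:

```python
import numpy as np

def gated_scores(q, k, alpha, delta, threshold):
    # Illustrative dense reference for the gate-augmented scores; the
    # Triton kernel fuses this per tile and additionally skips whole KV
    # tiles whose gate upper bound falls below `threshold`.
    scores = q @ k.T / np.sqrt(q.shape[-1])                    # scaled QK^T
    x = alpha * delta                                          # gate logits per key
    gate = np.minimum(x, 0.0) - np.log1p(np.exp(-np.abs(x)))   # stable logsigmoid
    scores = scores + gate                                     # add gate scores
    scores[:, gate < threshold] = -np.inf                      # mask sub-threshold keys
    return scores
```

Here `alpha` is treated as a per-head scalar and `delta` as a per-key vector for brevity; the actual parameter layout in the PR uses tensors of shape (B, H, L) and (B, H, S).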

Baseline metrics

| seqlen | Flash Attn Base (ms) | Flash Sparse Attn Base (ms) | Speedup |
| --- | --- | --- | --- |
| 1024 | 0.162 | 0.053 | 3.06x |
| 2048 | 0.528 | 0.124 | 4.25x |
| 4096 | 1.920 | 0.393 | 4.89x |
| 8192 | 7.224 | 1.422 | 5.08x |
| 16384 | 27.751 | 5.016 | 5.53x |
| 32768 | 111.028 | 18.841 | 5.89x |

Environment: bsz=1, num_heads=32, num_kv_heads=8 (GQA 4:1), head_dim=128, bfloat16, CUDA.
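
A quick sanity check (plain Python) reproduces the Speedup column from the two latency columns, up to rounding in the reported figures:

```python
dense_ms = [0.162, 0.528, 1.920, 7.224, 27.751, 111.028]   # Flash Attn Base
sparse_ms = [0.053, 0.124, 0.393, 1.422, 5.016, 18.841]    # Flash Sparse Attn Base
reported = [3.06, 4.25, 4.89, 5.08, 5.53, 5.89]
computed = [d / s for d, s in zip(dense_ms, sparse_ms)]
# Each computed ratio matches the reported speedup to within ~0.01.
```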

Approach

  1. Tile-level gate skipping (gate_skip): Before computing QK^T for a tile, the kernel evaluates the upper bound of alpha * delta across the tile using min/max statistics. If the upper bound is below the minimum gate threshold, the entire KV tile is skipped — avoiding both the matmul and the softmax update for that block.

  2. Fused log-sigmoid gating (log_sigmoid): Gate scores are computed inline as logsigmoid(alpha * delta) and added directly to the QK attention scores, with values below the threshold masked to -inf. A numerically stable implementation avoids overflow by using min(x, 0) - log(1 + exp(-|x|)) with a fast-path cutoff for large negative exponents.

  3. Dynamic gate threshold (get_gate_threshold): The threshold adapts per query position under causal masking — log(gate_scale * (q_idx + causal_offset + 1)) — reflecting the shrinking valid context window. For non-causal attention, a uniform threshold log(gate_scale * seqlen_k) is used.

  4. Varlen and GQA support: Both fixed-length and variable-length (packed) sequence layouts are supported, with cu_seqlens_q/cu_seqlens_k offset handling. Packed GQA is supported via QHEADS_PER_KVHEAD_PACKGQA.
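
The threshold and skip decision in steps 1 and 3 can be sketched host-side (illustrative Python, not the Triton code; argument names are assumptions, and the real kernel derives the bound from tile min/max statistics rather than per element):

```python
import math

def get_gate_threshold(gate_scale, seqlen_k, q_idx=0, causal_offset=0, causal=True):
    # Per-query threshold under causal masking; uniform for non-causal.
    if causal:
        return math.log(gate_scale * (q_idx + causal_offset + 1))
    return math.log(gate_scale * seqlen_k)

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|)).
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))

def tile_can_be_skipped(gate_logit_upper_bound, threshold):
    # logsigmoid is monotone increasing, so an upper bound on the tile's
    # alpha * delta logits bounds its gate scores; if even that bound is
    # below the threshold, the whole KV tile can be skipped.
    return log_sigmoid(gate_logit_upper_bound) < threshold
```

With gate_scale * seqlen_k < 1 (as enforced by the input assertions), the threshold is always negative, matching the range of logsigmoid.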

Results

The sparse kernel achieves 3.1x–5.9x speedup over the dense base kernel, with the advantage growing at longer sequences due to increasing sparsity. At 32K sequence length, latency drops from 111.0 ms to 18.8 ms.

The speedup scales super-linearly with sequence length because longer sequences have a higher proportion of gate values falling below the threshold, enabling more tiles to be skipped entirely.

Impact

  • Throughput: Up to 5.9x forward pass speedup at long context lengths with no architectural changes required — only the addition of lightweight alpha/delta gating parameters per head.
  • Memory: Minimal overhead from the alpha (B, H, L) and delta (B, H, S) tensors, which are negligible compared to Q/K/V.
  • Scaling: The benefit increases with sequence length, making this particularly valuable for long-context LLM inference and training.
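
To put the memory bullet in numbers (arithmetic only, using the benchmark shapes above and assuming S = L at the 32K setting):

```python
B, H, H_kv, L, D = 1, 32, 8, 32768, 128     # benchmark configuration above
bytes_bf16 = 2
qkv_bytes = (B * H * L * D + 2 * B * H_kv * L * D) * bytes_bf16   # Q + K + V
gate_bytes = (B * H * L + B * H * L) * bytes_bf16                 # alpha + delta
# gate_bytes is 4 MiB vs. 384 MiB for Q/K/V: roughly 1% overhead.
```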

Risks

  • Gate threshold correctness depends on gate_scale satisfying 0 < gate_scale * seqlen_k < 1. Input validation is enforced in assert_fwd_sparse_base_inputs.
  • The log_sigmoid fast-path cutoff (neg_abs_x < -8.0) introduces a small approximation error that is negligible in bfloat16 (verified: max diff < 2e-3 across all tested configurations).
  • Autotuning tile sizes may produce different optimal configurations than the base kernel; the same autotune config space is reused.
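
The fast-path error bound in the second risk is easy to verify directly:

```python
import math

# The fast path drops the log1p(exp(neg_abs_x)) correction when
# neg_abs_x < -8.0; since log1p(y) <= y for y >= 0, the absolute error
# is bounded by exp(-8) ~= 3.4e-4, well inside the stated 2e-3 tolerance.
max_dropped_term = math.log1p(math.exp(-8.0))
```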

Copilot AI review requested due to automatic review settings February 24, 2026 05:20

Copilot AI left a comment

Pull request overview

This PR adds a new Triton sparse attention forward kernel that uses learnable gate parameters (alpha, delta) to skip entire KV tiles at runtime based on a dynamic gate threshold, aiming to reduce FLOPs proportional to sparsity.

Changes:

  • Added a sparse forward input validator (assert_fwd_sparse_base_inputs) with gate_scale constraints.
  • Introduced a Triton helper (get_gate_threshold) to compute per-block gate thresholds (incl. causal adjustment + logit conversion).
  • Implemented a new Triton sparse forward kernel and Python wrappers for fixed-length and varlen execution, plus supporting activation utilities.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 10 comments.

| File | Description |
| --- | --- |
| flash_sparse_attn/ops/triton/utils.py | Adds sparse-forward input assertions, including the gate_scale * seqlen_k range check. |
| flash_sparse_attn/ops/triton/seqlen_info.py | Adds Triton get_gate_threshold() for dynamic/causal gate thresholds. |
| flash_sparse_attn/ops/triton/flash_sparse_fwd.py | New sparse forward Triton kernel plus fixed-length and varlen Python wrappers. |
| flash_sparse_attn/ops/triton/flash_fwd.py | Refactors dense forward to use activations.* softmax helpers (import change + call sites). |
| flash_sparse_attn/ops/triton/activations.py | Adds log_sigmoid() and gate_skip() used by the sparse kernel. |

@LoserCheems LoserCheems merged commit 02cf163 into main Feb 24, 2026