[PERFORMANCE OPTIMIZATION] Triton Sparse Base Forward Kernel with Gate-Based Sparsity#226
Merged
LoserCheems merged 8 commits intomainfrom Feb 24, 2026
Merged
[PERFORMANCE OPTIMIZATION] Triton Sparse Base Forward Kernel with Gate-Based Sparsity#226LoserCheems merged 8 commits intomainfrom
LoserCheems merged 8 commits intomainfrom
Conversation
…proved clarity and consistency
Contributor
There was a problem hiding this comment.
Pull request overview
This PR adds a new Triton sparse attention forward kernel that uses learnable gate parameters (alpha, delta) to skip entire KV tiles at runtime based on a dynamic gate threshold, aiming to reduce FLOPs proportional to sparsity.
Changes:
- Added a sparse forward input validator (
assert_fwd_sparse_base_inputs) withgate_scaleconstraints. - Introduced a Triton helper (
get_gate_threshold) to compute per-block gate thresholds (incl. causal adjustment + logit conversion). - Implemented a new Triton sparse forward kernel and Python wrappers for fixed-length and varlen execution, plus supporting activation utilities.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 10 comments.
Show a summary per file
| File | Description |
|---|---|
| flash_sparse_attn/ops/triton/utils.py | Adds sparse-forward input assertions including gate_scale * seqlen_k range check. |
| flash_sparse_attn/ops/triton/seqlen_info.py | Adds Triton get_gate_threshold() for dynamic/causal gate thresholds. |
| flash_sparse_attn/ops/triton/flash_sparse_fwd.py | New sparse forward Triton kernel + fixed/varlen Python wrappers. |
| flash_sparse_attn/ops/triton/flash_fwd.py | Refactors dense forward to use activations.* softmax helpers (import change + call sites). |
| flash_sparse_attn/ops/triton/activations.py | Adds log_sigmoid() and gate_skip() used by sparse kernel. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
…ward attention functions
…handling in forward functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Implements a Triton-based sparse attention forward kernel (
_fwd_sparse_base_kernel) that exploits gate-based sparsity to skip irrelevant attention blocks at runtime. The kernel introduces learnable per-head gating parameters (alpha,delta) with alogsigmoidactivation and a dynamic threshold derived fromgate_scale, enabling entire tiles to be skipped when the upper bound of their gate values falls below the threshold. This reduces the effective FLOPs proportional to the sparsity ratio without sacrificing model expressiveness.Baseline metrics
Environment: bsz=1, num_heads=32, num_kv_heads=8 (GQA 4:1), head_dim=128, bfloat16, CUDA.
Approach
Tile-level gate skipping (
gate_skip): Before computing QK^T for a tile, the kernel evaluates the upper bound ofalpha * deltaacross the tile using min/max statistics. If the upper bound is below the minimum gate threshold, the entire KV tile is skipped — avoiding both the matmul and the softmax update for that block.Fused log-sigmoid gating (
log_sigmoid): Gate scores are computed inline aslogsigmoid(alpha * delta)and added directly to the QK attention scores, with values below the threshold masked to-inf. A numerically stable implementation avoids overflow by usingmin(x, 0) - log(1 + exp(-|x|))with a fast-path cutoff for large negative exponents.Dynamic gate threshold (
get_gate_threshold): The threshold adapts per query position under causal masking —log(gate_scale * (q_idx + causal_offset + 1))— reflecting the shrinking valid context window. For non-causal attention, a uniform thresholdlog(gate_scale * seqlen_k)is used.Varlen and GQA support: Both fixed-length and variable-length (packed) sequence layouts are supported, with
cu_seqlens_q/cu_seqlens_koffset handling. Packed GQA is supported viaQHEADS_PER_KVHEAD_PACKGQA.Results
The sparse kernel achieves 3.1x–5.9x speedup over the dense base kernel, with the advantage growing at longer sequences due to increasing sparsity. At 32K sequence length, latency drops from 111.0 ms to 18.8 ms.
The speedup scales super-linearly with sequence length because longer sequences have a higher proportion of gate values falling below the threshold, enabling more tiles to be skipped entirely.
Impact
alpha/deltagating parameters per head.alpha(B, H, L) anddelta(B, H, S) tensors, which are negligible compared to Q/K/V.Risks
gate_scalesatisfying0 < gate_scale * seqlen_k < 1. Input validation is enforced inassert_fwd_sparse_base_inputs.log_sigmoidfast-path cutoff (neg_abs_x < -8.0) introduces a small approximation error that is negligible in bfloat16 (verified: max diff < 2e-3 across all tested configurations).Checklist