[PERFORMANCE OPTIMIZATION] Flash Sparse Attention #221
Conversation
Favours bias-based activation checks, deleting the mask tensors and related strides to cut branching in forward/backward Triton kernels. Unifies batch-head indexing, renames the custom autograd wrapper, and normalizes bias gradient accumulation so the kernels only reason about optional bias inputs.
Ensures the log-sum-exp pivot falls back to zero whenever the max accumulator is -inf so the probability exponentiation no longer relies on potentially invalid lse values, preventing NaN outputs in sparse attention
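A minimal sketch of the guard that commit describes, written as a Triton helper; the names `m_i` and `l_i` (running row max and row sum) are assumptions, not the PR's actual variables:

```python
import triton
import triton.language as tl

@triton.jit
def _safe_lse(m_i, l_i):
    # If an entire row was masked out, the running max m_i is still -inf and
    # log(l_i) would be undefined; pivot to 0.0 so later exp() calls see finite values.
    m_safe = tl.where(m_i == float("-inf"), 0.0, m_i)
    lse = m_safe + tl.log(tl.where(l_i == 0.0, 1.0, l_i))
    return lse
```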
- Removed multiple instantiation files for various configurations of the `run_mha_fwd_splitkv_dispatch` template, specifically for half-precision (fp16) and bfloat16 (bf16) types across different head dimensions (32, 64, 96).
- Consolidated template instantiations to reduce code duplication and improve compilation speed.
- Updated the remaining instantiation files to reflect the changes in template parameters, ensuring compatibility with the new structure.
- Ensured that all changes maintain the intended functionality while streamlining the codebase.
Removes the mask argument and treats bias as the sole auxiliary tensor so the forward/backward CUDA hooks share a single contiguous path. Drops the unused varlen wrappers and renames the autograd shim to reflect its sparse-attention scope, keeping gradient bookkeeping for bias consistent with the new shape contract.
Standardizes constructor and helper signatures for improved readability and consistent naming in the BlockInfo helper
Simplifies forward/backward params by folding bias fields into the core QKV struct, removing separate mask/bias state, and extending stride/head metadata to cover bias tensors. Aligns the dispatch interfaces to only specialize on type, head dim, and causality so future kernels share a single parameter surface.
Updates fwd/bwd kernel traits to treat mask/bias storage as a unified BSP buffer, removing separate toggles and types, so shared memory sizing and GMEM copies stay consistent with the new layout scheme
Simplifies the sparse attention mask helper by eliminating the optional external mask path, leaving only causal bounds checks so kernels no longer carry unused template parameters or branching.
Improves readability by expanding template declarations, arguments, and loops onto separate lines and aligning style with project conventions
Standardizes whitespace, braces, and argument wrapping across flash sparse attention helpers for improved readability and consistency. Adds a shared or_reduce utility to mirror existing mask reduction logic and prepares future call sites.
Eliminates the mask/bias template plumbing so the forward kernels only manage bias-like tensors through a unified shared-memory copy/reduction path, reducing predicates and redundant copies while keeping causal masking logic intact.
Removes mask/bias template permutations and unifies kernel macros to shrink the instantiation matrix and improve readability. Standardizes launch heuristics, shared-memory attributes, and combine-kernel selection so future head-dim additions stay manageable.
Improves readability of the backward preprocessing helpers by reformatting templates, conditionals, and tensor constructions for consistency. Aligns the dot(do,o) launch with the correct thread-partitioning trait to match the intended memory layout assumptions.
Replaces mask/bias-specific plumbing with generic B/DB tensors so backward attention reuses the BSP memory layouts, copy paths, and predicates. Updates template parameters, shared-memory tiling, and accumulation logic to match the new interfaces and keep dq/dk/dv computation consistent when bias rows are accumulated.
Removes mask and bias template booleans from backward launch helpers so fewer instantiations are needed while keeping causal specialization intact. Refactors macros, kernel invocations, and shared-memory configuration calls for clearer formatting and centralized unsupported-arch handling.
Removes mask and bias permutations so generated kernels match the leaner launch signatures and avoid unused instantiations
Unifies the forward and backward APIs around a single bias tensor, dropping mask-specific plumbing, stride bookkeeping, and template switches to reduce kernel variants and enforce consistent shape checks. Cleans up split-K heuristics and parameter naming while removing the unused variable-length entry points and their Python bindings.
Aligns the module and class names with the flash sparse attention implementation and drops the redundant mask creation so the operator relies on bias-based masking
Updates the backward kernel to use the renamed layout and QKV threading traits so it matches the latest Kernel_traits contract. Drops the unused mask alias to avoid stale references from previous definitions.
Ensures row maxima start at negative infinity and sums at zero so subsequent softmax passes receive deterministic initial values
Introduces a dedicated Triton kernel routine that applies padding and causal masks so the softmax remains numerically correct for uneven blocks and causal attention spans.
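A hedged illustration of the padding and causal masking step described above, as a Triton helper; `offs_m`/`offs_n` (tile row/column positions) and the helper name are hypothetical:

```python
import triton
import triton.language as tl

@triton.jit
def _apply_masks(scores, offs_m, offs_n, seqlen_k):
    # Padding: key columns beyond the real sequence length get -inf so they
    # contribute nothing to the softmax of uneven (partial) blocks.
    scores = tl.where(offs_n[None, :] < seqlen_k, scores, float("-inf"))
    # Causal: a query at position m may only attend to keys at positions <= m.
    scores = tl.where(offs_n[None, :] <= offs_m[:, None], scores, float("-inf"))
    return scores
```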
Introduces Triton kernels for streaming softmax that track row maxima, sums, and scaling, stabilizing accumulation through optional -inf checks and final rescaling.
Removes the first-block flag so row max and sum updates reuse one code path, relying on sentinel init values instead. Clarifies docs and clamps -inf row maxima when checking for invalid values, preventing NaN scaling on early tiles.
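The single-code-path bookkeeping those commits describe can be sketched outside Triton as well; this NumPy version is illustrative only (tile layout and names are assumptions):

```python
import numpy as np

def online_softmax_rowsum(score_tiles):
    """Accumulate a softmax denominator tile-by-tile without materializing all scores."""
    m = -np.inf  # sentinel init: the first tile needs no special-case flag
    s = 0.0      # running sum of exp(score - m)
    for tile in score_tiles:
        m_new = max(m, float(tile.max()))
        if m_new == -np.inf:
            continue  # fully masked so far: nothing to accumulate yet
        # Clamp the rescale factor while m is still the -inf sentinel to avoid NaN scaling.
        scale = 0.0 if m == -np.inf else np.exp(m - m_new)
        s = s * scale + np.exp(tile - m_new).sum()
        m = m_new
    return m, s

# Sanity check against a one-shot softmax denominator.
tiles = np.split(np.random.randn(128), 4)
m, s = online_softmax_rowsum(tiles)
assert np.isclose(s * np.exp(m), np.exp(np.concatenate(tiles)).sum())
```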
Introduces jit utilities to compute min/max block ranges for sparse attention, covering causal/local masking and Packed GQA constraints to support upcoming Triton kernels.
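Hypothetical Python equivalents of such block-range utilities (the PR's actual signatures may differ), assuming the usual bottom-right-aligned causal convention:

```python
def causal_kv_block_max(q_block_idx, block_m, block_n, seqlen_q, seqlen_k):
    """Exclusive upper bound on the KV block indices a causal query block may attend to."""
    last_q = (q_block_idx + 1) * block_m - 1
    # Bottom-right alignment shifts the causal diagonal by seqlen_k - seqlen_q.
    kv_bound = min(last_q + seqlen_k - seqlen_q + 1, seqlen_k)
    return (kv_bound + block_n - 1) // block_n  # ceil-divide to a block count

def local_kv_block_min(q_block_idx, block_m, block_n, window_left, seqlen_q, seqlen_k):
    """First KV block index a sliding-window (local) query block may attend to."""
    first_q = q_block_idx * block_m
    kv_low = max(first_q + seqlen_k - seqlen_q - window_left, 0)
    return kv_low // block_n
```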
Replaces Triton intrinsics with Python min/max so helper utilities can run outside kernels and avoid device-only calls
Introduces a Triton-based forward flash attention kernel with autotuning, masking, and online softmax to support efficient CUDA execution of grouped query attention
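Skeleton only, to show the autotuned-kernel shape the commit refers to; block sizes, key names, and the kernel body are placeholders, not the PR's implementation:

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8, num_stages=3),
    ],
    key=["seqlen_q", "seqlen_k", "HEAD_DIM"],
)
@triton.jit
def _fwd_kernel(Q, K, V, Out, seqlen_q, seqlen_k,
                HEAD_DIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Per-block loads, masked QK^T, online softmax, and PV accumulation would go here.
    pass
```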
Describes mask parameters using reStructuredText-style :param tags to align documentation with project conventions
Aligns docstrings with Sphinx-style annotations for clearer API documentation
Introduces reusable Triton helpers for computing per-batch offsets, sequence lengths, and padded positions so attention kernels can handle dynamic Q/K layouts with or without cumulative lengths.
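An illustrative host-side counterpart of the cu_seqlens bookkeeping those helpers handle; the convention (cumulative per-batch lengths), but not these exact names, is assumed from standard flash-attention varlen layouts:

```python
import torch

def batch_offsets(cu_seqlens: torch.Tensor, batch_idx: int):
    """Return (start, seqlen) of one batch entry in a packed, padding-free layout."""
    start = int(cu_seqlens[batch_idx])
    end = int(cu_seqlens[batch_idx + 1])
    return start, end - start

# Example: three sequences of lengths 5, 3, and 7 packed back-to-back.
cu = torch.tensor([0, 5, 8, 15], dtype=torch.int32)
print(batch_offsets(cu, 1))  # -> (5, 3)
```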
Adds cu_seqlens plumbing and seqlen-aware offsets so the forward kernel handles per-batch sequence lengths without padding overhead. Updates stride calculations and LSE/output allocation to cover ragged inputs while keeping the static-shape path untouched.
Passes optional window bounds from the Python wrapper down to the Triton kernel so forward kernels can reuse the global path for local attention windows.
Pull request overview
This PR refactors the flash sparse attention implementation by simplifying the template parameter structure, removing separate mask and bias flags in favor of a unified bias parameter handling approach. The changes eliminate redundant template instantiations and consolidate multiple parameter structures into a single unified structure.
Changes:
- Removed `HAS_MASK` and `HAS_BIAS` template parameters, reducing template instantiations from 8 combinations to 2 (causal/non-causal only)
- Merged `Mask_params` and `Bias_params` into `QKVB_params` structure with unified bias handling
- Updated kernel generation script to reflect simplified template structure
Reviewed changes
Copilot reviewed 299 out of 321 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `csrc/flash_sparse_attn/src/flash.h` | Consolidated parameter structures and simplified template signatures |
| `csrc/flash_sparse_attn/src/generate_kernels.py` | Updated kernel generation to remove mask/bias template parameters |
| `csrc/flash_sparse_attn/src/hardware_info.h` | Reformatted function calls for improved readability |
| `csrc/flash_sparse_attn/src/instantiations/*.cu` | Removed files with mask/bias combinations and updated remaining files with simplified template parameters |
| `CITATION.cff` | Updated project title |
Summary
This PR migrates Dynamic Mask Attention to the v2 algorithm to avoid quadratic memory growth with sequence length. The old approach materialized per‑query masks/tensors, while the new formulation uses a sparse, factorized representation that keeps memory linear in sequence length and head/window sizes. This preserves support for arbitrary masks while significantly reducing memory pressure.
Baseline metrics
N/A (baseline measurements to be filled by the maintainer based on their target hardware).
Baseline scenario: old DM attention with explicit mask materialization has $O(L_q \cdot L_k)$ auxiliary memory.
Approach
Results
Memory footprint scales approximately as $O(B \cdot H \cdot (L_q + L_k + W))$ instead of $O(B \cdot H \cdot L_q \cdot L_k)$, where $W$ is the window size.
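For illustration only (hypothetical sizes, not measured results): with $B=1$, $H=32$, $L_q=L_k=32768$, and $W=2048$, the auxiliary element counts compare as

```math
B \cdot H \cdot L_q \cdot L_k = 32 \cdot 32768^2 \approx 3.4 \times 10^{10}
\quad\text{vs.}\quad
B \cdot H \cdot (L_q + L_k + W) = 32 \cdot 67584 \approx 2.2 \times 10^{6}
```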
Impact
Risks
Checklist