
Conversation

@LoserCheems (Collaborator)

Summary
This PR migrates Dynamic Mask Attention to the v2 algorithm to avoid quadratic memory growth with sequence length. The old approach materialized per‑query masks/tensors, while the new formulation uses a sparse, factorized representation that keeps memory linear in sequence length and head/window sizes. This preserves support for arbitrary masks while significantly reducing memory pressure.

Baseline metrics
N/A (baseline measurements to be filled by the maintainer based on their target hardware).
Baseline scenario: old DM attention with explicit mask materialization has $O(L_q \cdot L_k)$ auxiliary memory.

Approach

  • Replace the old DM attention flow with the v2 sparse attention algorithm that factorizes the dynamic mask into components and avoids dense mask storage.
  • Compute attention with a block‑sparse formulation and on‑the‑fly masking, eliminating quadratic auxiliary tensors (see the sketch after this list).
  • Preserve arbitrary mask semantics while reducing memory bandwidth and enabling larger sequence lengths.
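
For illustration, here is a minimal NumPy sketch of the on-the-fly masking idea (not the actual CUDA/Triton kernels; causal masking stands in for the general dynamic mask): scores are computed tile by tile, the mask is derived from indices as needed, and only per-row running softmax statistics are kept.

```python
import numpy as np

def blocked_causal_attention(q, k, v, block_q=64, block_k=64):
    """Illustrative block-wise attention with on-the-fly causal masking:
    no (L_q x L_k) mask or score tensor is ever materialized, only per-tile
    scores plus O(L_q) running softmax statistics."""
    L_q, d = q.shape
    L_k = k.shape[0]
    out = np.empty_like(q, dtype=np.float64)
    for qs in range(0, L_q, block_q):
        qe = min(qs + block_q, L_q)
        row_max = np.full(qe - qs, -np.inf)        # sentinel init, no first-block flag
        row_sum = np.zeros(qe - qs)
        acc = np.zeros((qe - qs, d))
        for ks in range(0, L_k, block_k):
            ke = min(ks + block_k, L_k)
            if ks > qe - 1:                        # tile lies fully above the causal diagonal
                break
            s = q[qs:qe] @ k[ks:ke].T / np.sqrt(d)
            qi = np.arange(qs, qe)[:, None]        # mask derived from indices on the fly
            kj = np.arange(ks, ke)[None, :]
            s = np.where(kj <= qi, s, -np.inf)
            new_max = np.maximum(row_max, s.max(axis=1))
            rescale = np.exp(row_max - new_max)    # exp(-inf) = 0 on the first tile
            p = np.exp(s - new_max[:, None])
            row_sum = row_sum * rescale + p.sum(axis=1)
            acc = acc * rescale[:, None] + p @ v[ks:ke]
            row_max = new_max
        out[qs:qe] = acc / row_sum[:, None]
    return out

# check against a dense reference implementation
q, k, v = (np.random.randn(300, 32) for _ in range(3))
ref = q @ k.T / np.sqrt(32)
ref = np.where(np.tril(np.ones((300, 300), dtype=bool)), ref, -np.inf)
ref = np.exp(ref - ref.max(axis=1, keepdims=True))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ v
assert np.allclose(blocked_causal_attention(q, k, v), ref)
```

The same structure carries over when the masking decision is driven by a bias tensor instead of the causal predicate.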

Results
Memory footprint scales approximately $O(B \cdot H \cdot (L_q + L_k + W))$ instead of $O(B \cdot H \cdot L_q \cdot L_k)$, where $W$ is window size.
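
As a back-of-the-envelope illustration with hypothetical sizes (2-byte elements):

```python
# Rough comparison of auxiliary memory for the two scalings (hypothetical sizes, fp16).
B, H, L_q, L_k, W = 4, 16, 8192, 8192, 2048
bytes_per_elem = 2

dense_aux  = B * H * L_q * L_k * bytes_per_elem         # old: O(B * H * Lq * Lk)
sparse_aux = B * H * (L_q + L_k + W) * bytes_per_elem   # new: O(B * H * (Lq + Lk + W))

print(f"dense  ~ {dense_aux / 2**30:.1f} GiB")   # ~8.0 GiB
print(f"sparse ~ {sparse_aux / 2**20:.2f} MiB")  # ~2.25 MiB
```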

Impact

  • Memory: Large reduction from quadratic to near-linear auxiliary memory, enabling longer sequences.
  • Performance: Likely improved for long sequences due to reduced memory traffic; kernel launch overhead may increase slightly due to additional preprocessing.
  • Compatibility: Maintains arbitrary mask support.

Risks

  • Edge cases with very small sequence lengths or uncommon head dimensions may see reduced benefit.
  • Numerical behavior should be verified against prior implementation for equivalence.

Detailed changes

Favours bias-based activation checks, deleting the mask tensors and related strides to cut branching in forward/backward Triton kernels.
Unifies batch-head indexing, renames the custom autograd wrapper, and normalizes bias gradient accumulation so the kernels only reason about optional bias inputs.
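
A minimal NumPy sketch of the bias-only masking idea behind these items (illustrative, not the kernel code): whether a position is active is read off the additive bias itself, so no separate boolean mask tensor or its strides are needed.

```python
import numpy as np

# Illustrative only: with masking folded into an additive bias, the
# "is this position active?" question is answered from the bias values
# themselves, so no separate boolean mask tensor (or its strides) is needed.
scores = np.random.randn(4, 6)                 # (query, key) logits
bias = np.zeros((4, 6))
bias[:, 4:] = -np.inf                          # knock out the last two key positions
active = np.isfinite(bias)                     # mask recovered directly from the bias
logits = scores + bias
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
assert np.allclose(probs[:, ~active[0]], 0.0)  # masked keys receive zero weight
```
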
Ensures the log-sum-exp pivot falls back to zero whenever the max accumulator is -inf, so the probability exponentiation no longer relies on potentially invalid LSE values, preventing NaN outputs in sparse attention.
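
Sketch of that safeguard in NumPy (illustrative): when a row's max is still -inf, every score in the row was masked, so the pivot falls back to 0 and the exponentiation yields zeros instead of NaNs.

```python
import numpy as np

def probs_from_row(scores_row, row_max):
    """Illustrative: if the running max is -inf, every score in this row was
    masked, so pivot on 0.0 and exp() never sees (-inf) - (-inf) = NaN."""
    pivot = 0.0 if np.isneginf(row_max) else row_max
    return np.exp(scores_row - pivot)          # masked entries become exp(-inf) = 0

print(probs_from_row(np.full(4, -np.inf), -np.inf))   # [0. 0. 0. 0.] rather than NaNs
```
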
- Removes multiple instantiation files for various configurations of the `run_mha_fwd_splitkv_dispatch` template, specifically for half-precision (fp16) and bfloat16 (bf16) types across different head dimensions (32, 64, 96).
- Consolidates template instantiations to reduce code duplication and improve compilation speed.
- Updates the remaining instantiation files to reflect the changes in template parameters, ensuring compatibility with the new structure.
- Ensures that all changes maintain the intended functionality while streamlining the codebase.
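
Roughly, the generated instantiation matrix now only fans out over dtype, head dimension, and causality; the Python sketch below is illustrative (file names and the head-dim list are hypothetical, loosely modeled on `generate_kernels.py`).

```python
from itertools import product

# Illustrative sketch of the reduced instantiation matrix (file names and the
# head-dim list are hypothetical): kernels are only specialized on dtype,
# head dimension, and causality; the HAS_MASK x HAS_BIAS fan-out is gone.
DTYPES    = ["fp16", "bf16"]
HEAD_DIMS = [32, 64, 96, 128]
CAUSAL    = [False, True]

kernels = [
    f"flash_fwd_hdim{d}_{t}{'_causal' if c else ''}.cu"
    for t, d, c in product(DTYPES, HEAD_DIMS, CAUSAL)
]
print(len(kernels))   # 2 dtypes x 4 head dims x 2 causal variants = 16 files
# previously each of these also multiplied by the 4 mask/bias combinations
```
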
Removes the mask argument and treats bias as the sole auxiliary tensor so the forward/backward CUDA hooks share a single contiguous path.
Drops the unused varlen wrappers and renames the autograd shim to reflect its sparse-attention scope, keeping gradient bookkeeping for bias consistent with the new shape contract.
Standardizes constructor and helper signatures for improved readability and consistent naming in the BlockInfo helper.
Simplifies forward/backward params by folding bias fields into the core QKV struct, removing separate mask/bias state, and extending stride/head metadata to cover bias tensors.
Aligns the dispatch interfaces to only specialize on type, head dim, and causality so future kernels share a single parameter surface.
Updates fwd/bwd kernel traits to treat mask/bias storage as a unified BSP buffer, removing separate toggles and types, so shared memory sizing and GMEM copies stay consistent with the new layout scheme.
Simplifies the sparse attention mask helper by eliminating the optional external mask path, leaving only causal bounds checks so kernels no longer carry unused template parameters or branching.
Improves readability by expanding template declarations, arguments, and loops onto separate lines and aligning style with project conventions.
Standardizes whitespace, braces, and argument wrapping across flash sparse attention helpers for improved readability and consistency.
Adds a shared or_reduce utility to mirror existing mask reduction logic and prepares future call sites.
Eliminates the mask/bias template plumbing so the forward kernels only manage bias-like tensors through a unified shared-memory copy/reduction path, reducing predicates and redundant copies while keeping causal masking logic intact.
Removes mask/bias template permutations and unifies kernel macros to shrink the instantiation matrix and improve readability.
Standardizes launch heuristics, shared-memory attributes, and combine-kernel selection so future head-dim additions stay manageable.
Improves readability of the backward preprocessing helpers by reformatting templates, conditionals, and tensor constructions for consistency. Aligns the dot(do,o) launch with the correct thread-partitioning trait to match the intended memory layout assumptions.
Replaces mask/bias-specific plumbing with generic B/DB tensors so backward attention reuses the BSP memory layouts, copy paths, and predicates. Updates template parameters, shared-memory tiling, and accumulation logic to match the new interfaces and keep dq/dk/dv computation consistent when bias rows are accumulated.
Removes mask and bias template booleans from backward launch helpers so fewer instantiations are needed while keeping causal specialization intact.

Refactors macros, kernel invocations, and shared-memory configuration calls for clearer formatting and centralized unsupported-arch handling.
Removes mask and bias permutations so generated kernels match the leaner launch signatures and avoid unused instantiations.
Unifies the forward and backward APIs around a single bias tensor, dropping mask-specific plumbing, stride bookkeeping, and template switches to reduce kernel variants and enforce consistent shape checks. Cleans up split-K heuristics and parameter naming while removing the unused variable-length entry points and their Python bindings.
Aligns the module and class names with the flash sparse attention implementation and drops the redundant mask creation so the operator relies on bias-based masking.
Updates the backward kernel to use the renamed layout and QKV threading traits so it matches the latest Kernel_traits contract.
Drops the unused mask alias to avoid stale references from previous definitions.
Ensures row maxima start at negative infinity and sums at zero so subsequent softmax passes receive deterministic initial values.
Introduces a dedicated Triton kernel routine that applies padding and causal masks so the softmax remains numerically correct for uneven blocks and causal attention spans.
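
Illustrative NumPy version of that combined padding-plus-causal predicate for a tile whose key block may run past the true sequence length (names are hypothetical):

```python
import numpy as np

def mask_tile(scores, q_start, k_start, seqlen_k, causal=True):
    """Illustrative: apply padding + causal masking to one tile of scores
    whose key block may run past the true seqlen_k (uneven tail block)."""
    m, n = scores.shape
    qi = q_start + np.arange(m)[:, None]       # absolute query rows of this tile
    kj = k_start + np.arange(n)[None, :]       # absolute key columns of this tile
    valid = kj < seqlen_k                      # padding mask for out-of-range columns
    if causal:
        valid = valid & (kj <= qi)             # causal mask
    return np.where(valid, scores, -np.inf)

print(mask_tile(np.zeros((4, 8)), q_start=4, k_start=0, seqlen_k=6))
```
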
Introduces Triton kernels for streaming softmax that track row maxima, sums, and scaling, stabilizing accumulation through optional -inf checks and final rescaling.
Removes the first-block flag so row max and sum updates reuse one code path, relying on sentinel init values instead.
Clarifies docs and clamps -inf row maxima when checking for invalid values, preventing NaN scaling on early tiles.
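
The streaming update these items describe, isolated as a small NumPy sketch (illustrative): sentinel initialization (max = -inf, sum = 0) removes the first-block special case, and a still -inf max is clamped before being used as the pivot.

```python
import numpy as np

def online_softmax_step(row_max, row_sum, tile_scores):
    """Illustrative single step of the streaming softmax: sentinel init
    (row_max = -inf, row_sum = 0) removes the need for a first-tile flag,
    and a still -inf max is clamped to 0 before it is used as the pivot."""
    new_max = np.maximum(row_max, tile_scores.max(axis=1))
    pivot = np.where(np.isneginf(new_max), 0.0, new_max)   # clamp fully-masked rows
    rescale = np.exp(row_max - pivot)                      # exp(-inf) = 0 on the first tile
    row_sum = row_sum * rescale + np.exp(tile_scores - pivot[:, None]).sum(axis=1)
    return new_max, row_sum, rescale   # rescale is also applied to the output accumulator

scores = np.random.randn(3, 12)
m = np.full(3, -np.inf)                # sentinel row maxima
s = np.zeros(3)                        # sentinel row sums
for tile in np.split(scores, 3, axis=1):
    m, s, _ = online_softmax_step(m, s, tile)
assert np.allclose(s, np.exp(scores - m[:, None]).sum(axis=1))
```
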
Introduces jit utilities to compute min/max block ranges for sparse attention, covering causal/local masking and Packed GQA constraints to support upcoming Triton kernels.
Replaces Triton intrinsics with Python min/max so helper utilities can run outside kernels and avoid device-only calls.
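
A plain-Python sketch of such a block-range helper (hypothetical name and signature), runnable and testable outside any Triton kernel:

```python
def kv_block_range(q_block, block_m, block_n, seqlen_k, causal, window):
    """Illustrative helper: [lo, hi) range of key blocks visible to query block
    `q_block` under causal and local-window constraints, using only plain
    Python min/max so it can be unit-tested outside a Triton kernel."""
    q_lo = q_block * block_m                    # first query row in the block
    q_hi = q_lo + block_m - 1                   # last query row in the block
    lo = 0
    hi = (seqlen_k + block_n - 1) // block_n    # all key blocks by default
    if causal:
        hi = min(hi, q_hi // block_n + 1)       # nothing right of the diagonal
    if window is not None:
        lo = max(lo, (q_lo - window) // block_n)  # nothing left of the local window
    return lo, hi

# e.g. query block 3 with 64x64 tiles, causal, 256-wide local window
print(kv_block_range(3, 64, 64, seqlen_k=4096, causal=True, window=256))   # (0, 4)
```
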
Introduces a Triton-based forward flash attention kernel with autotuning, masking, and online softmax to support efficient CUDA execution of grouped query attention.
Describes mask parameters using reStructuredText-style :param tags to align documentation with project conventions.
Aligns docstrings with Sphinx-style annotations for clearer API documentation.
Introduces reusable Triton helpers for computing per-batch offsets, sequence lengths, and padded positions so attention kernels can handle dynamic Q/K layouts with or without cumulative lengths.
Adds cu_seqlens plumbing and seqlen-aware offsets so the forward kernel handles per-batch sequence lengths without padding overhead.
Updates stride calculations and LSE/output allocation to cover ragged inputs while keeping the static-shape path untouched.
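
Illustrative helper for the cu_seqlens bookkeeping (hypothetical names): the per-batch start offset and sequence length are recovered from the cumulative lengths, so ragged inputs need no padding.

```python
import numpy as np

def batch_offsets(cu_seqlens, batch_id):
    """Illustrative: recover (start offset, sequence length) for one batch
    element from cumulative sequence lengths, as used for ragged (unpadded)
    Q/K layouts."""
    start = int(cu_seqlens[batch_id])
    seqlen = int(cu_seqlens[batch_id + 1]) - start
    return start, seqlen

# three sequences of lengths 5, 3, and 7 packed back to back
cu = np.array([0, 5, 8, 15])
print(batch_offsets(cu, 1))   # (5, 3)
```
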
Passes optional window bounds from the Python wrapper down to the Triton kernel so forward kernels can reuse the global path for local attention windows.
Copilot AI review requested due to automatic review settings January 16, 2026 12:08
Copilot AI (Contributor) left a comment

Pull request overview

This PR refactors the flash sparse attention implementation by simplifying the template parameter structure, removing separate mask and bias flags in favor of a unified bias parameter handling approach. The changes eliminate redundant template instantiations and consolidate multiple parameter structures into a single unified structure.

Changes:

  • Removed HAS_MASK and HAS_BIAS template parameters, reducing template instantiations from 8 combinations to 2 (causal/non-causal only)
  • Merged Mask_params and Bias_params into QKVB_params structure with unified bias handling
  • Updated kernel generation script to reflect simplified template structure

Reviewed changes

Copilot reviewed 299 out of 321 changed files in this pull request and generated no comments.

Summary per file

  • csrc/flash_sparse_attn/src/flash.h: Consolidated parameter structures and simplified template signatures
  • csrc/flash_sparse_attn/src/generate_kernels.py: Updated kernel generation to remove mask/bias template parameters
  • csrc/flash_sparse_attn/src/hardware_info.h: Reformatted function calls for improved readability
  • csrc/flash_sparse_attn/src/instantiations/*.cu: Removed files with mask/bias combinations and updated remaining files with simplified template parameters
  • CITATION.cff: Updated project title

@LoserCheems merged commit 906846d into main on Jan 16, 2026.