Conversation

@LoserCheems (Collaborator)
Summary

Adds optional attention mask and bias to Flash Dynamic Mask Attention. Both can now be omitted or provided in either 4D broadcastable form (B, Hx, S_q, S_k) or compact 3D form (B, Hx, S_k) to broadcast over all query positions. Supports forward and backward paths without changing existing callers that already pass 4D tensors. Resolves issue #161.

Design

  • Accept std::optional<at::Tensor> for mask and bias in forward/backward APIs.
  • Support two layouts:
    • 4D: (batch, heads_variant, seqlen_q, seqlen_k)
    • 3D: (batch, heads_variant, seqlen_k) → internally expanded with a stride-0 broadcast dim at seqlen_q position.
  • Kernel template dispatch extended with Has_mask / Has_bias booleans; no runtime branching inside core math loops.
  • Parameter structs record has_mask / has_bias; if absent, pointers set to nullptr and all related strides set to 0.
  • Broadcasting implementation:
    • 3D inputs are turned into a 4D expanded view (unsqueeze + expand) for uniform downstream handling.
    • Row stride becomes 0 so kernel reuses the same memory per query row.
  • Backward:
    • dbias only allocated/computed when bias is provided.
    • For GQA/MQA (num_heads_bias != num_heads), a head reduction is performed.
    • For 3D bias inputs (broadcast over S_q), an additional reduction over the query dimension is applied after kernel accumulation.
  • Safety: zero-length (seqlen_k == 0 or seqlen_q == 0) paths set outputs / LSE to neutral values.
  • Softcap, split-K, causal handling unchanged.
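
The 3D-to-4D expansion described above can be sketched in NumPy terms. This is a reference sketch only, under the assumption that the extension's unsqueeze + expand produces a stride-0 query axis; the helper name is hypothetical and the real code operates on ATen tensors.

```python
import numpy as np

def expand_mask_to_4d(t, seqlen_q):
    """Insert a broadcast seqlen_q axis into a (B, H, S_k) mask/bias so
    downstream code sees a uniform 4D layout; the new axis has stride 0,
    so every query row reads the same memory."""
    if t.ndim == 3:
        b, h, sk = t.shape
        t = np.broadcast_to(t[:, :, None, :], (b, h, seqlen_q, sk))
    return t

mask3d = np.ones((2, 4, 16), dtype=bool)
mask4d = expand_mask_to_4d(mask3d, seqlen_q=8)
assert mask4d.shape == (2, 4, 8, 16)
assert mask4d.strides[2] == 0  # broadcast axis reuses the same row
```

A 4D input passes through unchanged, so callers that already supply (B, Hx, S_q, S_k) tensors see no difference.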

Changes

Public C++ / Python Extension:

  • Forward: fwd(q, k, v, mask=None, bias=None, out=None, ...)
  • Backward: bwd(dout, q, k, v, mask=None, bias=None, out, softmax_lse, dq=None, dk=None, dv=None, dbias=None, ...)
  • Internal param setup functions extended with has_mask / has_bias flags.
  • Kernel instantiation generator now enumerates (Has_mask, Has_bias) combinations.

Supported mask/bias head axis: 1 | num_heads_k | num_heads.

Implementation Notes

  • Added has_mask / has_bias to Flash_fwd_params / Flash_bwd_params; template BOOL_SWITCH expands specialized kernels.
  • 3D broadcast relies on stride(-2)==0 after expand; row_stride=0 stored in params → safe reuse.
  • Prevented invalid memory access by guarding creation / copy of mask/bias tensors with if constexpr.
  • Backward dbias:
    • Uses contiguous 4D buffer when original passed as 3D to avoid writing into a stride-0 dimension.
    • Post-kernel reductions: (heads) then (seqlen_q broadcast) if needed.
  • Head & sequence reductions use at::sum_out to avoid extra allocations.
  • No performance regression expected for no-mask/no-bias path (dispatch selects specialized kernels without extra conditionals).
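
The two post-kernel dbias reductions (heads, then the broadcast query axis) can be expressed as a NumPy reference. Shapes and names here are illustrative assumptions, not the extension's actual buffers.

```python
import numpy as np

# Hb bias heads shared by H query heads (GQA-style), accumulated per query head.
B, H, Hb, Sq, Sk = 2, 8, 2, 4, 16
dbias_acc = np.random.default_rng(0).standard_normal((B, H, Sq, Sk))

# Head reduction: every group of H // Hb query heads feeds one bias head.
dbias_4d = dbias_acc.reshape(B, Hb, H // Hb, Sq, Sk).sum(axis=2)
assert dbias_4d.shape == (B, Hb, Sq, Sk)

# Extra query-dimension reduction when the original bias was 3D.
dbias_3d = dbias_4d.sum(axis=2)
assert dbias_3d.shape == (B, Hb, Sk)
```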

Tests

Recommended (added / to add):

  • Forward equivalence (no mask/bias vs previous behavior).
  • Forward with 3D mask (all True) equals no-mask output.
  • Forward with random bias (compare against reference softmax(QK^T * scale + bias)).
  • Backward grad check (finite differences) with:
    • No mask/bias
    • 4D bias only
    • 3D bias only
    • 4D mask + 4D bias
    • 3D mask + 3D bias
  • GQA / MQA cases (num_heads_k divides num_heads) with/without bias.
  • seqlen_q == 1 decoding swap path.
  • Edge: seqlen_k == 0, seqlen_q == 0 (output zero / LSE inf).
  • Broadcast reduction correctness: verify dbias shape & summed values for 3D bias input.
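
A minimal NumPy oracle for the forward equivalence tests above might look like the following. It checks two of the listed properties: an all-True mask matches the no-mask output, and a bias constant over the key axis leaves softmax unchanged. This is a test reference only, not the kernel.

```python
import numpy as np

def ref_attention(q, k, v, mask=None, bias=None):
    """Reference softmax(QK^T * scale + bias) with boolean masking."""
    scale = q.shape[-1] ** -0.5
    s = np.einsum("bhqd,bhkd->bhqk", q, k) * scale
    if bias is not None:
        s = s + bias
    if mask is not None:
        s = np.where(mask, s, -np.inf)
    p = np.exp(s - s.max(-1, keepdims=True))
    p = p / p.sum(-1, keepdims=True)
    return np.einsum("bhqk,bhkd->bhqd", p, v)

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 2, 4, 16))
k = rng.standard_normal((1, 2, 8, 16))
v = rng.standard_normal((1, 2, 8, 16))

all_true = np.ones((1, 2, 1, 8), dtype=bool)  # 3D mask after broadcast expand
assert np.allclose(ref_attention(q, k, v, mask=all_true), ref_attention(q, k, v))

const_bias = np.full((1, 2, 1, 8), 3.0)  # constant over keys: softmax-invariant
assert np.allclose(ref_attention(q, k, v, bias=const_bias), ref_attention(q, k, v))
```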

Docs

  • Update README / API docs to describe:
    • Optional mask / bias
    • Accepted shapes (3D & 4D)
    • Broadcast semantics over query length
    • Returned dbias shape (matches input broadcast rank: 3D in → 3D out after reduction).
  • Mention performance note: 3D broadcast introduces no material overhead.

Follow-ups

  • Add benchmark comparing 3D vs 4D bias/mask memory throughput.
  • Consider lazy materialization of expanded bias for split-K if future optimization needed.
  • Optional: expose an API flag to return raw softmax P only when bias present to reduce allocations.

Commits

Introduces h_h_mask_ratio and h_h_bias_ratio fields to precompute head ratios for mask and bias parameters, following the existing pattern used for query/key head ratios.

Also adds total_k dimension field and includes TODO comment for potential block mask memory optimization.
Reorders stride and head parameter assignments for better logical grouping and consistency between forward and backward pass implementations.

Adds missing ratio calculations for mask and bias heads to complement existing head-to-key ratio computation.

Fixes indentation inconsistency in batch stride calculations.
Removes redundant conditional logic for computing head indices by directly using the ratio-based calculations inline.

Previously used intermediate variables with complex conditional expressions that duplicated the same logic pattern. Now directly computes head indices using the division by ratio parameters, making the code more readable and eliminating unnecessary variables.
Removes conditional logic for computing head indices and replaces it with direct ratio-based calculations using h_h_mask_ratio and h_h_bias_ratio parameters.

This eliminates the need for intermediate variables and conditional branches, making the code more straightforward and potentially improving performance.
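
The ratio-based indexing described above reduces to a single integer division per head. A hypothetical sketch, using the h_h_bias_ratio naming from this PR (the real kernel does this inline in CUDA):

```python
def bias_head_index(h, num_heads, num_heads_bias):
    """Map a query head to its bias head via the precomputed ratio."""
    h_h_bias_ratio = num_heads // num_heads_bias  # stored in the param struct
    return h // h_h_bias_ratio

# 8 query heads sharing 2 bias heads: heads 0-3 -> 0, heads 4-7 -> 1.
assert [bias_head_index(h, 8, 2) for h in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
# Broadcast case (head axis = 1): every query head maps to bias head 0.
assert all(bias_head_index(h, 8, 1) == 0 for h in range(8))
```

The same division covers the 1, num_heads_k, and num_heads head-axis variants with ratios num_heads, num_heads // num_heads_k, and 1 respectively, which is why no conditional branch is needed.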
Introduces has_mask and has_bias boolean fields to Mask_params and Bias_params structures respectively.

These flags enable runtime detection of whether mask or bias parameters are present, improving conditional logic handling and potentially optimizing performance by avoiding unnecessary processing when these optional components are not used.
Introduces optional mask and bias parameters to prevent accessing null tensors when these features are disabled.

Previously, mask and bias tensors were always accessed regardless of whether they contained valid data, which could cause errors or undefined behavior.

Now uses conditional checks to only access tensor data and stride information when the corresponding features are actually enabled, improving robustness and allowing for optional mask/bias functionality.
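
A Python-level sketch of this guarded parameter setup, with hypothetical field names (the real code fills a C++ param struct): only touch tensor data and strides when the feature is enabled, otherwise store null/zero so the disabled path never dereferences anything.

```python
import numpy as np

def set_mask_params(params: dict, mask) -> None:
    """Record mask presence; leave pointer null and strides zero if absent."""
    params["has_mask"] = mask is not None
    if mask is not None:
        params["mask_ptr"] = mask
        # stride 0 along seqlen_q for a broadcast (expanded 3D) mask
        params["mask_row_stride"] = mask.strides[-2] // mask.itemsize
    else:
        params["mask_ptr"] = None
        params["mask_row_stride"] = 0

p = {}
set_mask_params(p, None)
assert p == {"has_mask": False, "mask_ptr": None, "mask_row_stride": 0}

m3d = np.ones((2, 4, 16), dtype=np.float32)
m4d = np.broadcast_to(m3d[:, :, None, :], (2, 4, 8, 16))  # expanded view
set_mask_params(p, m4d)
assert p["has_mask"] and p["mask_row_stride"] == 0
```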
Changes mask and bias parameters from required tensor references to optional tensor references in both forward and backward multi-head attention functions.

Improves API flexibility by allowing these attention modifiers to be omitted when not needed, reducing memory overhead and simplifying function calls for basic attention operations.

Updates parameter comments to use consistent formatting with curly brace notation for dimension alternatives.
Converts mandatory mask and bias parameters to optional parameters by wrapping them in std::optional.

Adds proper validation and initialization logic to handle cases where mask or bias are not provided, creating empty tensors as placeholders when needed.

Updates all related tensor operations, shape checking, and kernel parameter passing to conditionally process these optional inputs throughout both forward and backward passes.
Refactors template signatures across forward, backward, and split-KV kernels to include additional boolean parameters for mask and bias support.

Updates all kernel instantiations to use expanded template parameters, maintaining compatibility with existing causal-only configurations while enabling new combinations of mask and bias features.

Removes formatting inconsistencies in include statements and standardizes template parameter ordering across all instantiation files.
Extends the flash attention kernel generator to support additional template parameters for masking and bias operations.

Updates all kernel templates to include HAS_MASK and HAS_BIAS parameters, allowing for more flexible attention implementations with optional masking and bias addition.

Modifies the kernel filename generation to include mask and bias flags for better organization and identification of generated kernel variants.

Changes the default output directory to a more structured path and ensures directory creation before writing files.
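
The generator change can be sketched as an enumeration over the new flag pairs. The filename scheme below is an assumption for illustration; the actual names produced by generate_kernels.py may differ.

```python
from itertools import product

def kernel_filenames(head_dims=(32, 64, 128), dtypes=("fp16", "bf16")):
    """Enumerate one instantiation file per (hdim, dtype, HAS_MASK, HAS_BIAS)."""
    names = []
    for hdim, dtype, has_mask, has_bias in product(
        head_dims, dtypes, (False, True), (False, True)
    ):
        names.append(
            f"flash_fwd_hdim{hdim}_{dtype}"
            f"{'_mask' if has_mask else ''}{'_bias' if has_bias else ''}_sm80.cu"
        )
    return names

names = kernel_filenames()
assert len(names) == 3 * 2 * 4  # four mask/bias combinations per config
assert "flash_fwd_hdim32_fp16_mask_bias_sm80.cu" in names
```

Encoding the flags in the filename keeps each variant in its own translation unit, which is what preserves the per-file compilation speedup.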
Extends forward and backward multi-head attention function templates with Has_mask and Has_bias template parameters to enable conditional mask and bias functionality during attention computation.
Extends both forward and backward multi-head attention kernels to support additional mask and bias parameters through new template arguments.

Enhances kernel flexibility by allowing attention mechanisms to handle custom masking patterns and bias terms beyond just causal masking.
Introduces template parameters to conditionally compile mask and bias operations, enabling performance optimizations when these features are not needed.

Replaces runtime checks with constexpr conditions to eliminate unnecessary computations and memory accesses when mask or bias are disabled.

Improves mask condition logic from float comparison to boolean evaluation for more reliable masking behavior.
Extends kernel templates with Has_mask and Has_bias boolean parameters
to enable attention masking and bias functionality.

Updates all kernel function signatures and template instantiations
to accommodate the new parameters while maintaining backward compatibility.

Removes commented debug code and consolidates CUDA function attribute
setting for improved code clarity.
Extends backward kernel templates with Has_mask and Has_bias parameters to enable attention masking and bias functionality during gradient computation.

Updates all kernel instantiations and function signatures to propagate the new template parameters through the call chain, maintaining consistency across different head dimensions and device configurations.

Includes minor code formatting improvements for better readability.
Extends the template parameter list to include Has_mask and Has_bias flags
for better flexibility in handling attention mechanisms with masks and biases.

Updates all function calls to pass through the new template parameters
while maintaining backward compatibility with existing functionality.
Extends kernel templates with Has_mask and Has_bias boolean parameters to support attention masking and bias operations.

Updates all affected function signatures and call sites to maintain consistency across the attention computation pipeline.
Copilot AI review requested due to automatic review settings September 18, 2025 14:31

Copilot AI left a comment
Pull Request Overview

This PR adds optional attention mask and bias support to Flash Dynamic Mask Attention with both 3D and 4D tensor formats. The implementation extends the template dispatch system to handle mask/bias presence through compile-time boolean flags, enabling specialized kernel generation for different combinations of causal, mask, and bias configurations.

  • Extends template parameters with Has_mask and Has_bias boolean flags for compile-time optimization
  • Adds support for 3D tensor broadcast over query dimension alongside existing 4D format
  • Generates specialized kernel instantiations for all mask/bias combinations

Reviewed Changes

Copilot reviewed 296 out of 296 changed files in this pull request and generated 2 comments.

File Description
csrc/flash_dmattn/src/mask.h Core mask application logic with template specialization for different mask/bias combinations
csrc/flash_dmattn/src/instantiations/*.cu Updated kernel instantiations with additional template parameters for mask/bias flags


// Without the "make_coord" we get wrong results
auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
// Apply scaling and bias or masking
tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord))
Copilot AI Sep 18, 2025
The mask condition logic has changed from mask(coord) == 0.0f to !mask(coord). This assumes the mask tensor contains boolean-like values, but the change should be documented or verified to ensure compatibility with existing mask formats.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h"

Copilot AI Sep 18, 2025

Missing newline between the comment and the #include directive. The comment and include statement should be on separate lines.

Suggested change
// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h"
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"

Extends mask and bias tensor handling to accept 3-dimensional inputs by automatically expanding them to 4D format with a dummy seqlen_q dimension.

Removes rigid shape validation checks that prevented flexible tensor dimensions and updates tensor creation logic to handle both 3D and 4D formats appropriately.

Ensures backward pass correctly reduces the dummy dimension when original bias was 3D to maintain output shape consistency.
Introduces template parameters to conditionally enable mask and bias functionality in flash attention kernels.

Optimizes shared memory allocation by only reserving space for mask and bias when actually needed, reducing memory footprint when these features are disabled.
Improves performance by using larger block sizes when masks and bias are not present.

Uses adaptive block sizing strategy that considers head size to maximize throughput for cases without attention masks or bias terms.
Refactors flash attention kernel selection to use compile-time conditionals that specialize kernel configurations based on the presence of mask and bias operations.

Updates block size calculations to use larger values when mask/bias are absent, improving performance for simpler attention patterns.

Replaces runtime shared memory checks with more granular per-configuration memory thresholds, enabling better hardware utilization across different GPU architectures.
Optimizes memory usage and performance by making mask and bias operations conditional based on template parameters Has_mask and Has_bias.

Prevents unnecessary memory allocation and computation when mask or bias features are not needed, reducing shared memory footprint and eliminating redundant operations.

Updates tensor creation logic to avoid allocating memory for unused mask/bias tensors and wraps all mask/bias related operations in compile-time conditionals.
Replaces runtime conditionals with compile-time `constexpr` checks for mask and bias operations to improve performance by eliminating unnecessary computations when features are disabled.

Reduces code duplication by creating specialized branches for different combinations of mask and bias availability, allowing the compiler to optimize out unused code paths.

Eliminates redundant tensor copies and memory operations when mask or bias are not present, leading to better register usage and reduced memory bandwidth.
Prevents accidental modification of mask and bias parameters in the apply_mask function by making them const references, improving code safety and expressing intent more clearly.
Corrects the backward pass logic to properly handle bias tensors with missing sequence length dimension.

Previously, 3D bias tensors were incorrectly expanded during gradient computation, leading to shape mismatches. Now properly detects when bias lacks the sequence length dimension and sums gradients across that dimension appropriately.

Ensures gradient tensor is properly zeroed and handles both MQA/GQA cases and 3D bias tensors correctly.
@LoserCheems LoserCheems requested a review from Copilot September 19, 2025 17:00
Copilot AI left a comment

Pull Request Overview

Copilot reviewed 300 out of 301 changed files in this pull request and generated 2 comments.



@@ -0,0 +1,10 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h"
Copilot AI Sep 19, 2025
Missing newline between the comment and the #include directive; the comment and the include should be on separate lines.

Suggested change
// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h"
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"

@LoserCheems LoserCheems merged commit 483b726 into main Sep 19, 2025
@LoserCheems LoserCheems deleted the make-mask/bias-optional branch November 13, 2025 04:41