[FEATURE SUPPORT] Optional mask/bias (3D & 4D) #170
Conversation
Introduces `h_h_mask_ratio` and `h_h_bias_ratio` fields to precompute head ratios for the mask and bias parameters, following the existing pattern used for query/key head ratios. Also adds a `total_k` dimension field and a TODO comment for a potential block-mask memory optimization.
Reorders stride and head parameter assignments for better logical grouping and consistency between forward and backward pass implementations. Adds missing ratio calculations for mask and bias heads to complement existing head-to-key ratio computation. Fixes indentation inconsistency in batch stride calculations.
Removes redundant conditional logic for computing head indices by directly using the ratio-based calculations inline. Previously used intermediate variables with complex conditional expressions that duplicated the same logic pattern. Now directly computes head indices using the division by ratio parameters, making the code more readable and eliminating unnecessary variables.
Removes conditional logic for computing head indices and replaces it with direct ratio-based calculations using h_h_mask_ratio and h_h_bias_ratio parameters. This eliminates the need for intermediate variables and conditional branches, making the code more straightforward and potentially improving performance.
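A minimal standalone sketch of the ratio-based indexing described in the two commits above, assuming the `h_h_mask_ratio` naming from the commit text; the surrounding scaffolding is illustrative, not the kernel's actual code:

```cpp
#include <cstdio>

int main() {
    const int num_heads = 8;       // query heads
    const int num_heads_mask = 2;  // mask heads (broadcast/grouped layout)
    // Precomputed once, as the commit describes.
    const int h_h_mask_ratio = num_heads / num_heads_mask;

    for (int h = 0; h < num_heads; ++h) {
        // Direct inline division replaces the old conditional expressions.
        const int h_mask = h / h_h_mask_ratio;
        printf("query head %d -> mask head %d\n", h, h_mask);
    }
    return 0;
}
```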
Introduces has_mask and has_bias boolean fields to Mask_params and Bias_params structures respectively. These flags enable runtime detection of whether mask or bias parameters are present, improving conditional logic handling and potentially optimizing performance by avoiding unnecessary processing when these optional components are not used.
Introduces optional mask and bias parameters to prevent accessing null tensors when these features are disabled. Previously, mask and bias tensors were always accessed regardless of whether they contained valid data, which could cause errors or undefined behavior. Now uses conditional checks to only access tensor data and stride information when the corresponding features are actually enabled, improving robustness and allowing for optional mask/bias functionality.
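A hedged sketch of the parameter plumbing these two commits describe; only `has_mask` and the `Mask_params` struct name come from the text above, while the remaining fields and the helper function are assumptions:

```cpp
#include <cstdint>

struct Mask_params {
    const void* mask_ptr = nullptr;  // null when no mask is supplied
    int64_t mask_batch_stride = 0;
    int64_t mask_head_stride = 0;
    int64_t mask_row_stride = 0;
    bool has_mask = false;           // runtime flag introduced by this change
};

// Only read data pointers and strides when the feature is enabled, so a
// disabled mask never dereferences an undefined tensor.
inline void set_mask_params(Mask_params& p, const void* data,
                            int64_t batch_stride, int64_t head_stride,
                            int64_t row_stride) {
    p.has_mask = (data != nullptr);
    p.mask_ptr = data;
    if (p.has_mask) {
        p.mask_batch_stride = batch_stride;
        p.mask_head_stride = head_stride;
        p.mask_row_stride = row_stride;
    }
}
```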
…gs for conditional bias management
Changes mask and bias parameters from required tensor references to optional tensor references in both forward and backward multi-head attention functions. Improves API flexibility by allowing these attention modifiers to be omitted when not needed, reducing memory overhead and simplifying function calls for basic attention operations. Updates parameter comments to use consistent formatting with curly brace notation for dimension alternatives.
Converts mandatory mask and bias parameters to optional parameters by wrapping them in std::optional. Adds proper validation and initialization logic to handle cases where mask or bias are not provided, creating empty tensors as placeholders when needed. Updates all related tensor operations, shape checking, and kernel parameter passing to conditionally process these optional inputs throughout both forward and backward passes.
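A sketch of the `std::optional` handling described above, under the assumption of a PyTorch C++ extension; the helper name and the single check shown are illustrative, and the real entry points validate considerably more:

```cpp
#include <optional>
#include <torch/extension.h>

torch::Tensor materialize(const std::optional<torch::Tensor>& opt,
                          const torch::Tensor& q, bool& present) {
    present = opt.has_value() && opt->defined();
    if (present) {
        TORCH_CHECK(opt->device() == q.device(),
                    "mask/bias must be on the same device as q");
        return *opt;
    }
    // Empty placeholder so downstream code always holds a valid tensor.
    return torch::empty({0}, q.options());
}
```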
Refactors template signatures across forward, backward, and split-KV kernels to include additional boolean parameters for mask and bias support. Updates all kernel instantiations to use expanded template parameters, maintaining compatibility with existing causal-only configurations while enabling new combinations of mask and bias features. Removes formatting inconsistencies in include statements and standardizes template parameter ordering across all instantiation files.
Extends the flash attention kernel generator to support additional template parameters for masking and bias operations. Updates all kernel templates to include HAS_MASK and HAS_BIAS parameters, allowing for more flexible attention implementations with optional masking and bias addition. Modifies the kernel filename generation to include mask and bias flags for better organization and identification of generated kernel variants. Changes the default output directory to a more structured path and ensures directory creation before writing files.
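The combinatorial effect on the generated files can be illustrated with a small program; the naming pattern below is an assumption modeled on the flags mentioned above, not the exact scheme in `generate_kernels.py`:

```cpp
#include <cstdio>
#include <string>

int main() {
    const char* dtypes[] = {"fp16", "bf16"};
    const int headdims[] = {32, 64, 96, 128};
    for (const char* dtype : dtypes)
        for (int hd : headdims)
            for (int causal = 0; causal <= 1; ++causal)
                for (int has_mask = 0; has_mask <= 1; ++has_mask)
                    for (int has_bias = 0; has_bias <= 1; ++has_bias) {
                        // One instantiation file per flag combination.
                        std::string name =
                            "flash_fwd_hdim" + std::to_string(hd) + "_" + dtype
                            + (causal ? "_causal" : "")
                            + (has_mask ? "_mask" : "")
                            + (has_bias ? "_bias" : "") + "_sm80.cu";
                        std::printf("%s\n", name.c_str());
                    }
    return 0;
}
```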
Extends forward and backward multi-head attention function templates with Has_mask and Has_bias template parameters to enable conditional mask and bias functionality during attention computation.
Extends both forward and backward multi-head attention kernels to support additional mask and bias parameters through new template arguments. Enhances kernel flexibility by allowing attention mechanisms to handle custom masking patterns and bias terms beyond just causal masking.
Introduces template parameters to conditionally compile mask and bias operations, enabling performance optimizations when these features are not needed. Replaces runtime checks with constexpr conditions to eliminate unnecessary computations and memory accesses when mask or bias are disabled. Improves mask condition logic from float comparison to boolean evaluation for more reliable masking behavior.
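The shape of that `constexpr` specialization might look like the following sketch; the function name and scoring details are illustrative, not the kernel's actual code:

```cpp
template <bool Has_mask, bool Has_bias>
float score_element(float score, bool mask_keep, float bias,
                    bool out_of_range, float neg_inf) {
    if constexpr (Has_mask) {
        // Boolean evaluation of the mask (was a float == 0.0f comparison).
        if (out_of_range || !mask_keep) return neg_inf;
    } else {
        if (out_of_range) return neg_inf;
    }
    if constexpr (Has_bias) {
        score += bias;  // compiled out entirely when Has_bias is false
    }
    return score;
}
```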
Extends kernel templates with Has_mask and Has_bias boolean parameters to enable attention masking and bias functionality. Updates all kernel function signatures and template instantiations to accommodate the new parameters while maintaining backward compatibility. Removes commented debug code and consolidates CUDA function attribute setting for improved code clarity.
Extends backward kernel templates with Has_mask and Has_bias parameters to enable attention masking and bias functionality during gradient computation. Updates all kernel instantiations and function signatures to propagate the new template parameters through the call chain, maintaining consistency across different head dimensions and device configurations. Includes minor code formatting improvements for better readability.
Extends the template parameter list to include Has_mask and Has_bias flags for better flexibility in handling attention mechanisms with masks and biases. Updates all function calls to pass through the new template parameters while maintaining backward compatibility with existing functionality.
Extends kernel templates with Has_mask and Has_bias boolean parameters to support attention masking and bias operations. Updates all affected function signatures and call sites to maintain consistency across the attention computation pipeline.
Pull Request Overview
This PR adds optional attention mask and bias support to Flash Dynamic Mask Attention with both 3D and 4D tensor formats. The implementation extends the template dispatch system to handle mask/bias presence through compile-time boolean flags, enabling specialized kernel generation for different combinations of causal, mask, and bias configurations.
- Extends template parameters with `Has_mask` and `Has_bias` boolean flags for compile-time optimization
- Adds support for 3D tensor broadcast over the query dimension alongside the existing 4D format
- Generates specialized kernel instantiations for all mask/bias combinations
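As a rough sketch, the runtime-to-compile-time dispatch over the two new flags could take the following form (the repository presumably uses a BOOL_SWITCH-style macro; this functional variant and the kernel name in the usage comment are placeholders):

```cpp
#include <type_traits>

template <typename Fn>
void dispatch_mask_bias(bool has_mask, bool has_bias, Fn&& fn) {
    if (has_mask && has_bias)  fn(std::true_type{},  std::true_type{});
    else if (has_mask)         fn(std::true_type{},  std::false_type{});
    else if (has_bias)         fn(std::false_type{}, std::true_type{});
    else                       fn(std::false_type{}, std::false_type{});
}

// Usage:
//   dispatch_mask_bias(params.has_mask, params.has_bias,
//       [&](auto HasMask, auto HasBias) {
//           run_flash_fwd<decltype(HasMask)::value,
//                         decltype(HasBias)::value>(params);
//       });
```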
Reviewed Changes
Copilot reviewed 296 out of 296 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| csrc/flash_dmattn/src/mask.h | Core mask application logic with template specialization for different mask/bias combinations |
| csrc/flash_dmattn/src/instantiations/*.cu | Updated kernel instantiations with additional template parameters for mask/bias flags |
```cpp
// Without the "make_coord" we get wrong results
auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
// Apply scaling and bias or masking
tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord))
```
Copilot (AI) commented on Sep 18, 2025:
The mask condition logic has changed from `mask(coord) == 0.0f` to `!mask(coord)`. This assumes the mask tensor contains boolean-like values, but the change should be documented or verified to ensure compatibility with existing mask formats.
```cpp
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h"
```
Copilot (AI) commented on Sep 18, 2025:
Missing newline between the comment and the `#include` directive. The comment and include statement should be on separate lines.
Suggested change:
```diff
-// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h"
+// This file is auto-generated. See "generate_kernels.py"
+#include "namespace_config.h"
```
Extends mask and bias tensor handling to accept 3-dimensional inputs by automatically expanding them to 4D format with a dummy seqlen_q dimension. Removes rigid shape validation checks that prevented flexible tensor dimensions and updates tensor creation logic to handle both 3D and 4D formats appropriately. Ensures backward pass correctly reduces the dummy dimension when original bias was 3D to maintain output shape consistency.
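A minimal sketch of the 3D-to-4D expansion, assuming PyTorch C++ tensors; the helper name is hypothetical:

```cpp
#include <torch/torch.h>

// (B, Hx, S_k) -> (B, Hx, 1, S_k): insert a dummy seqlen_q axis so the
// kernels see a uniform 4D layout; 4D inputs pass through unchanged.
torch::Tensor expand_to_4d(const torch::Tensor& t, bool& was_3d) {
    was_3d = (t.dim() == 3);
    return was_3d ? t.unsqueeze(2) : t;
}
```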
Introduces template parameters to conditionally enable mask and bias functionality in flash attention kernels. Optimizes shared memory allocation by only reserving space for mask and bias when actually needed, reducing memory footprint when these features are disabled.
Improves performance by using larger block sizes when masks and bias are not present. Uses adaptive block sizing strategy that considers head size to maximize throughput for cases without attention masks or bias terms.
Refactors flash attention kernel selection to use compile-time conditionals that specialize kernel configurations based on the presence of mask and bias operations. Updates block size calculations to use larger values when mask/bias are absent, improving performance for simpler attention patterns. Replaces runtime shared memory checks with more granular per-configuration memory thresholds, enabling better hardware utilization across different GPU architectures.
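A sketch of the adaptive block-size idea described above; the tile sizes and the head-size threshold below are placeholders, not the kernel's tuned values:

```cpp
template <int Headdim, bool Has_mask, bool Has_bias>
constexpr int block_n() {
    if constexpr (!Has_mask && !Has_bias) {
        // No mask/bias tiles in shared memory, so larger blocks fit.
        return Headdim <= 64 ? 256 : 128;
    } else {
        return Headdim <= 64 ? 128 : 64;
    }
}
```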
Optimizes memory usage and performance by making mask and bias operations conditional based on template parameters Has_mask and Has_bias. Prevents unnecessary memory allocation and computation when mask or bias features are not needed, reducing shared memory footprint and eliminating redundant operations. Updates tensor creation logic to avoid allocating memory for unused mask/bias tensors and wraps all mask/bias related operations in compile-time conditionals.
Replaces runtime conditionals with compile-time `constexpr` checks for mask and bias operations to improve performance by eliminating unnecessary computations when features are disabled. Reduces code duplication by creating specialized branches for different combinations of mask and bias availability, allowing the compiler to optimize out unused code paths. Eliminates redundant tensor copies and memory operations when mask or bias are not present, leading to better register usage and reduced memory bandwidth.
Prevents accidental modification of mask and bias parameters in the apply_mask function by making them const references, improving code safety and expressing intent more clearly.
Corrects the backward pass logic to properly handle bias tensors with missing sequence length dimension. Previously, 3D bias tensors were incorrectly expanded during gradient computation, leading to shape mismatches. Now properly detects when bias lacks the sequence length dimension and sums gradients across that dimension appropriately. Ensures gradient tensor is properly zeroed and handles both MQA/GQA cases and 3D bias tensors correctly.
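The backward-side reduction could be sketched as follows (PyTorch C++, hypothetical helper; the MQA/GQA head reduction mentioned above is omitted for brevity):

```cpp
#include <torch/torch.h>

// dbias_4d: (B, Hx, S_q, S_k) gradient accumulated by the kernel.
// If the user passed a 3D bias, sum over the dummy seqlen_q axis so the
// returned gradient matches the original (B, Hx, S_k) shape.
torch::Tensor reduce_dbias(const torch::Tensor& dbias_4d, bool bias_was_3d) {
    return bias_was_3d ? dbias_4d.sum(/*dim=*/2) : dbias_4d;
}
```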
…s in flash_dynamic_mask_attention_forward
…sk_attention_forward
Pull Request Overview
Copilot reviewed 300 out of 301 changed files in this pull request and generated 2 comments.
```
@@ -0,0 +1,10 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h"
```
Copilot (AI) commented on Sep 19, 2025:
Missing newline between the comment and the `#include` directive. Should be `"generate_kernels.py"\n#include`.
Suggested change:
```diff
-// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h"
+// This file is auto-generated. See "generate_kernels.py"
+#include "namespace_config.h"
```
Summary
Adds optional attention mask and bias to Flash Dynamic Mask Attention. Both can now be omitted or provided in either 4D broadcastable form (B, Hx, S_q, S_k) or compact 3D form (B, Hx, S_k) to broadcast over all query positions. Supports forward and backward paths without changing existing callers that already pass 4D tensors. Resolves issue #161.
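For illustration, the accepted shapes look like this (PyTorch C++; the q/k/v layout and the commented-out call are assumptions, since the binding's exact signature isn't shown here):

```cpp
#include <torch/torch.h>

int main() {
    const int64_t B = 2, H = 8, Hk = 2, Sq = 128, Sk = 256, D = 64;
    auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
    auto q = torch::randn({B, Sq, H,  D}, opts);
    auto k = torch::randn({B, Sk, Hk, D}, opts);
    auto v = torch::randn({B, Sk, Hk, D}, opts);

    // 4D broadcastable form: (B, Hx, S_q, S_k), Hx in {1, num_heads_k, num_heads}.
    auto bias_4d = torch::randn({B, Hk, Sq, Sk}, opts);
    // Compact 3D form: (B, Hx, S_k), broadcast over all query positions.
    auto bias_3d = torch::randn({B, Hk, Sk}, opts);
    auto mask_3d = torch::ones({B, 1, Sk}, opts.dtype(torch::kBool));

    // auto out = mha_fwd(q, k, v, /*mask=*/mask_3d, /*bias=*/bias_3d, ...);  // placeholder name
    return 0;
}
```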
Design
Changes
Public C++ / Python Extension:
Supported mask/bias head axis: `1 | num_heads_k | num_heads`.
Implementation Notes
Tests
Recommended (added / to add):
Docs
Checklist
Follow-ups