[FEATURE SUPPORT] Convert attention mask storage from float to bool #166
Conversation
Introduces template parameters to enable converting boolean values to numeric elements during copy operations. Adds conditional logic that converts true values to 1.0f and false values to 0.0f when the Bool_to_Element flag is enabled, allowing for more flexible data type transformations in memory copy routines.
Changes mask pointer casting from generic Element to const bool for type safety. Updates copy_MN template calls to include Clear_OOB_MN and Bool_to_Element parameters for proper mask handling. Comments out async fence and wait operations, likely for debugging or performance optimization.
Changes mask pointer casting from generic Element to const bool type for proper type safety. Updates copy operations to include Bool_to_Element template parameter for correct boolean-to-element conversion. Replaces asynchronous copy fences with synchronous thread synchronization to ensure proper data consistency before mask operations.
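The conversion itself happens inside the CUDA copy routines, but its numerical effect is simple to state: a `True` entry in the bool mask becomes 1.0 in the compute dtype and `False` becomes 0.0, so the downstream softmax/masking math is unchanged. A minimal PyTorch sketch of that mapping (illustrative only, not the kernel code):

```python
import torch

# Illustration only: the CUDA kernel performs this widening during the
# gmem -> smem copy; here the same mapping is reproduced in PyTorch.
bool_mask = torch.tensor([[True, True, False, False]])   # stored as bool (1 byte/element)
element_mask = bool_mask.to(torch.float16)                # True -> 1.0, False -> 0.0
assert torch.equal(
    element_mask,
    torch.tensor([[1.0, 1.0, 0.0, 0.0]], dtype=torch.float16),
)
```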
Changes mask validation to require boolean dtype instead of matching query dtype across all attention functions. Comments out variable length forward and backward pass functions in the Python binding module.
Improves type consistency and performance by using torch.bool instead of generic dtype for attention masks. Eliminates unnecessary type conversions and simplifies mask comparison logic by using False instead of 0.0 comparisons.
Specifies that attention_bias parameter expects a float tensor to improve API documentation clarity and help developers understand the expected data type.
Replaces float-based attention mask operations with boolean dtype for improved memory efficiency and cleaner logic. Removes unnecessary dtype conversion and simplifies mask creation by using boolean tensors directly instead of converting comparison results to float values.
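As a sketch of that pattern (the shapes and the particular predicate are hypothetical, not the exact integration code):

```python
import torch

# Hypothetical sketch: build the mask directly in torch.bool (True = attend)
# instead of materializing 0.0/1.0 floats and converting afterwards.
seqlen_q, seqlen_k = 4, 4
attn_mask = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool).tril()
# The result of any comparison is already a bool tensor, so no
# .to(query.dtype) conversion is needed before passing it on.
```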
Eliminates variable-length sequence support to simplify the codebase and focus on standard batch-based attention operations. Removes forward and backward implementations for variable-length sequences along with their fake wrappers, reducing code complexity and maintenance overhead. Fixes mask and bias handling in the remaining implementation to properly handle None values during padding operations.
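A rough sketch of the padding behavior described above, using hypothetical helper and argument names (the real interface code may differ): the mask is normalized to `torch.bool` and padded with `False`, the bias is padded with `-inf`, and `None` inputs are passed through untouched.

```python
import torch
import torch.nn.functional as F

def _pad_mask_and_bias(attn_mask, attn_bias, pad_len):
    # Hypothetical helper mirroring the described behavior; not the actual code.
    if attn_mask is not None:
        if attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(torch.bool)                   # normalize legacy float masks
        attn_mask = F.pad(attn_mask, (0, pad_len), value=False)    # padded keys are not attended
    if attn_bias is not None:
        attn_bias = F.pad(attn_bias, (0, pad_len), value=float("-inf"))
    return attn_mask, attn_bias
```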
Pull Request Overview
This PR converts attention mask storage from float16/bfloat16 to bool type to reduce memory footprint and bandwidth by ~50% while preserving numerical behavior (a rough sketch of the saving follows the list below). The masks are stored as bool in global memory and converted to compute dtype (Element) only when loading into shared memory.
- Convert mask tensor storage from float to bool in global memory
- Update CUDA kernels to use bool-to-Element conversion during gmem→smem transfers
- Enforce mask dtype validation to `torch::kBool` in the C++ API
- Update Python interface to normalize masks to bool and pad with `False`
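As a rough, illustrative check of the ~50% figure (assuming the mask was previously stored in the query dtype at 2 bytes/element, and using arbitrary shapes):

```python
import torch

batch, heads, seqlen_q, seqlen_k = 1, 8, 1024, 1024   # illustrative shapes only
old_mask = torch.zeros(batch, heads, seqlen_q, seqlen_k, dtype=torch.float16)
new_mask = torch.zeros(batch, heads, seqlen_q, seqlen_k, dtype=torch.bool)

old_bytes = old_mask.numel() * old_mask.element_size()   # 2 bytes/element
new_bytes = new_mask.numel() * new_mask.element_size()   # 1 byte/element
print(new_bytes / old_bytes)                              # 0.5 -> ~50% smaller
```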
Reviewed Changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py | Updated mask creation to use bool dtype and fixed comparison operators |
| flash_dmattn/integrations/flash_dynamic_mask_attention.py | Updated docstring to clarify bias tensor type |
| flash_dmattn/flash_dmattn_interface.py | Removed varlen functions, updated mask padding logic and docstrings |
| csrc/flash_dmattn/src/utils.h | Added Bool_to_Element template parameter for mask conversion |
| csrc/flash_dmattn/src/flash_fwd_kernel.h | Updated mask pointers to bool and enabled Bool_to_Element conversion |
| csrc/flash_dmattn/src/flash_bwd_kernel.h | Updated mask pointers to bool and enabled Bool_to_Element conversion |
| csrc/flash_dmattn/flash_api.cpp | Changed mask dtype validation from q_dtype to torch::kBool |
| benchmarks/*.py | Updated mask creation to use bool dtype consistently |
Replaces verbose static_cast operations with more concise To_type constructor calls when converting boolean values to numeric types. Improves code readability while maintaining the same functionality of converting true to 1 and false to 0.
Pull Request Overview
Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.
Replaces `== False` comparison with the more idiomatic `~` operator for boolean negation, improving code readability and following Python best practices.
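A one-line illustration of the change (not the actual diff):

```python
import torch

mask = torch.tensor([True, False, True])
old_style = mask == False   # noqa: E712 (flagged by linters, hence the change)
new_style = ~mask
assert torch.equal(old_style, new_style)
```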
Co-authored-by: Copilot <[email protected]>
Pull Request Overview
Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.
Comments suppressed due to low confidence (1)
flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py:1
- This line converts attention_mask to the compute dtype, but the mask should remain as bool according to the PR's goal of storing masks as bool in global memory.
# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
    seqlen_k = k.shape[1]
    is_grad = is_grad_enabled and any(
        x.requires_grad for x in [q, k, v]
    )
Copilot AI · Sep 12, 2025
The variables batch_size and seqlen_q were removed but may still be needed later in the function. The code should verify these variables are not used elsewhere in the forward method.
Summary
#163
Design
- Store the attention mask in global memory as `bool` (1 byte/element).
- On load (gmem→smem), convert `bool` to the compute Element (fp16/bf16) and proceed with the same softmax/masking pipeline.

Alternatives considered:
Changes
- Kernels (`flash_fwd_kernel.h`, `flash_bwd_kernel.h`): cast mask pointers to `const bool*` and use `copy_MN<..., Bool_to_Element=true>(...)` to convert to Element in smem.
- `utils.h`: extend `copy_MN` with a `Bool_to_Element` compile-time path (defaults keep old behavior). Maintains signature compatibility via defaults.
- `flash_api.cpp`: validate `mask.dtype() == torch::kBool` (aligned with kernel behavior).
- Python interface: convert the mask to `torch.bool` if not already.
- Padding: pad the mask with `False` (and bias with -inf as before).

Implementation Notes
Tests
Docs
- `attn_mask` must be `torch.bool`.
- Padding value for `attn_mask` is `False`.

Checklist