Refactor create_mask function parameters #218
Conversation
Pull request overview
This PR refactors the create_mask function's parameter signature to improve clarity by reordering parameters and making several optional. The changes include removing batch_size and key_len as explicit parameters (now derived from attention_bias), moving query_len to the second position, and converting most parameters to keyword-only with defaults. The version is also bumped from 1.2.3 to 1.2.4.
Key Changes:
- Simplified the `create_mask` function signature by removing the `batch_size` and `key_len` parameters (now derived internally from `attention_bias`)
- Reordered parameters: `query_len` moved to the 2nd position, `type` made 3rd with a default, and the remaining parameters converted to keyword-only with defaults
- Updated the version number to 1.2.4
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| flash_sparse_attn/utils/mask.py | Refactored create_mask function signature to remove batch_size and key_len parameters, reorder remaining parameters, and derive batch/key dimensions from attention_bias shape |
| flash_sparse_attn/__init__.py | Version bump from 1.2.3 to 1.2.4 |
In `flash_sparse_attn/utils/mask.py`:

```diff
 def create_mask(
     attention_bias: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    batch_size: int,
     query_len: int,
-    key_len: int,
-    window_size: Optional[int],
-    min_dtype: Optional[float],
-    block_size: Optional[int],
     type: str = "topk",
+    attention_mask: Optional[torch.Tensor] = None,
+    window_size: Optional[int] = None,
+    min_dtype: Optional[float] = None,
+    block_size: Optional[int] = None,
 ) -> torch.Tensor:
```
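For orientation, a minimal sketch of how a call site might look against the refactored signature. The import path is inferred from the changed-files table above; the tensor shapes and the values chosen for `window_size`, `min_dtype`, and `block_size` are illustrative assumptions, not values from the PR.

```python
import torch

# Module path assumed from the changed-files table (flash_sparse_attn/utils/mask.py).
from flash_sparse_attn.utils.mask import create_mask

batch, heads, query_len, key_len = 2, 8, 128, 128               # illustrative shapes
attention_bias = torch.randn(batch, heads, query_len, key_len)
attention_mask = torch.ones(batch, key_len, dtype=torch.bool)   # 2-D padding mask

mask = create_mask(
    attention_bias,
    query_len,                    # now the 2nd positional parameter
    type="topk",
    attention_mask=attention_mask,
    window_size=64,               # required for type="topk" (see the review comments below)
    min_dtype=torch.finfo(attention_bias.dtype).min,
    block_size=32,
)
```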
Copilot AI · Dec 20, 2025
The PR description states "No backward compatibility issues; the function signature has changed but maintains the same functionality." However, this refactor introduces breaking changes to the API:
- The parameter order has changed: `query_len` moved from 4th to 2nd position, and `type` moved from keyword-only to 3rd position with a default
- Required parameters `batch_size` and `key_len` have been removed entirely
- `attention_mask` moved from the 2nd positional slot to a keyword-only parameter
This means existing code using positional arguments will break. For example, an old call like `create_mask(bias, mask, batch_sz, q_len, k_len, window, min_val, type="topk")` will fail with the new signature, because the 2nd argument (which would be `mask`) is now interpreted as `query_len` (expecting an int).
Multiple call sites in the codebase have not been updated, including those in:
- benchmarks/backward_equivalence.py (lines 101, 172, 248, 321)
- benchmarks/forward_equivalence.py (lines 101, 163, 230, 302)
- benchmarks/backward_performance.py (lines 107, 189, 268, 345)
- benchmarks/forward_performance.py (lines 108, 185, 261, 341)
- flash_sparse_attn/integrations/modeling_flash_sparse_attention_utils.py (line 649)
Consider either updating all call sites in the same PR or providing a deprecation path to maintain backward compatibility.
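One concrete shape the suggested deprecation path could take is a thin legacy wrapper that keeps the old positional order, emits a `DeprecationWarning`, and forwards to the new signature. This is a sketch only: `create_mask_legacy` is a hypothetical name, and the import path for the refactored `create_mask` is assumed from the changed-files table.

```python
import warnings
from typing import Optional

import torch

from flash_sparse_attn.utils.mask import create_mask  # new signature (assumed import path)


def create_mask_legacy(
    attention_bias: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    batch_size: int,
    query_len: int,
    key_len: int,
    window_size: Optional[int],
    min_dtype: Optional[float],
    block_size: Optional[int],
    type: str = "topk",
) -> torch.Tensor:
    """Accepts the pre-refactor positional order and forwards to the new create_mask."""
    warnings.warn(
        "create_mask_legacy is deprecated; call create_mask with the new keyword-only parameters.",
        DeprecationWarning,
        stacklevel=2,
    )
    # batch_size and key_len are intentionally unused: the refactored function
    # derives them from attention_bias.shape.
    del batch_size, key_len
    return create_mask(
        attention_bias,
        query_len,
        type=type,
        attention_mask=attention_mask,
        window_size=window_size,
        min_dtype=min_dtype,
        block_size=block_size,
    )
```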
```python
# If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len)
if attention_mask is not None and attention_mask.dim() == 2:
    batch_size, key_len = attention_bias.shape[0], attention_bias.shape[-1]
```
Copilot AI · Dec 20, 2025
The variables batch_size and key_len are now derived from attention_bias.shape instead of being passed as parameters. While this simplifies the API, it could potentially cause issues if attention_bias has been broadcast or has unexpected dimensions. Consider adding a validation check to ensure attention_bias has the expected rank (4 dimensions) before extracting shape values, especially since the docstring indicates some dimensions can be 1 (broadcast dimensions).
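A minimal sketch of the kind of guard being suggested, assuming `attention_bias` is expected to be a 4-D (batch_size, num_heads, query_len, key_len) tensor as the docstring implies; this check is not part of the PR.

```python
# Hypothetical guard (not in the PR): validate the rank of attention_bias
# before reading shape values from it.
if attention_bias.dim() != 4:
    raise ValueError(
        "attention_bias is expected to be 4-D (batch_size, num_heads, query_len, key_len), "
        f"but got a tensor of shape {tuple(attention_bias.shape)}"
    )
batch_size, key_len = attention_bias.shape[0], attention_bias.shape[-1]
```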
```
batch_size (int): The batch size.
query_len (int): The sequence length of the query.
key_len (int): The sequence length of the key.
window_size (Optional[int]): The number of top elements to consider for the attention mask.
```
Copilot AI · Dec 20, 2025
The documentation for window_size describes it as "The number of top elements to consider for the attention mask" but doesn't indicate that it's required when type="topk". Since topk_mask requires a non-None window_size parameter, the documentation should clarify this requirement. Consider updating the description to something like: "The number of top elements to consider for the attention mask. Required when type='topk', ignored when type='relu'."
Suggested change:

```diff
-window_size (Optional[int]): The number of top elements to consider for the attention mask.
+window_size (Optional[int]): The number of top elements to consider for the attention mask. Required when type='topk', ignored when type='relu'.
```
```
key_len (int): The sequence length of the key.
window_size (Optional[int]): The number of top elements to consider for the attention mask.
min_dtype (Optional[float]): The minimum value to use for masking.
block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
```
Copilot AI · Dec 20, 2025
The documentation describes block_size as being "after top-k masking", but this parameter is actually used by both topk_mask and relu_mask functions. The description should be updated to reflect that it applies to both mask types, not just top-k. Consider: "Optional size of aggregation blocks to smooth the resulting mask along the key dimension."
Suggested change:

```diff
-block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
+block_size (Optional[int]): Optional size of aggregation blocks to smooth the resulting mask along the key dimension.
```
```diff
 from typing import Optional

-__version__ = "1.2.3"
+__version__ = "1.2.4"
```
Copilot AI · Dec 20, 2025
The version is being bumped from 1.2.3 to 1.2.4 (a patch version increment), but this PR introduces breaking changes to the create_mask API. According to semantic versioning, breaking changes should result in a major version bump (2.0.0) or at minimum a minor version bump (1.3.0) if the API is still in the 0.x or 1.x phase and breaking changes are acceptable. A patch version (1.2.4) should only be used for backward-compatible bug fixes. Since the function signature change breaks existing callers that use positional arguments, this warrants a more significant version change.
Suggested change:

```diff
-__version__ = "1.2.4"
+__version__ = "2.0.0"
```
```python
    window_size: Optional[int] = None,
    min_dtype: Optional[float] = None,
    block_size: Optional[int] = None,
```
Copilot AI · Dec 20, 2025
The `window_size` parameter is now optional with a default value of None, but when type="topk" this parameter is required by the `topk_mask` function at line 223. If a caller uses type="topk" without providing `window_size`, the code will fail with a TypeError when trying to pass None to `topk_mask`, which expects an int. Consider either:
- making `window_size` required when type="topk" by adding validation (sketched below), or
- providing a sensible default value for `window_size`, or
- making the `topk_mask` function handle None gracefully.
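A minimal sketch of the first option, assuming the check would sit near the top of `create_mask`; the exact message and placement are illustrative, not part of the PR.

```python
# Hypothetical validation (not in the PR): reject type="topk" without a window_size
# up front instead of letting topk_mask fail later with a less helpful TypeError.
if type == "topk" and window_size is None:
    raise ValueError('window_size must be provided when type="topk"')
```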
Summary
Refactor the `create_mask` function parameters and update the version number.

Root Cause
The parameter signature of `create_mask` lacked clarity, making it difficult to understand the function's usage.

Changes
Refactored the parameters of the `create_mask` function for better clarity and consistency.

Reproduction

Tests

Compatibility
No backward compatibility issues; the function signature has changed but maintains the same functionality.

Checklist