Refactor create_mask function parameters #218
```diff
@@ -163,14 +163,12 @@ def relu_mask(

 def create_mask(
     attention_bias: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    batch_size: int,
-    query_len: int,
-    key_len: int,
-    window_size: Optional[int],
-    min_dtype: Optional[float],
-    block_size: Optional[int],
+    type: str = "topk",
+    attention_mask: Optional[torch.Tensor] = None,
+    window_size: Optional[int] = None,
+    min_dtype: Optional[float] = None,
+    block_size: Optional[int] = None,
 ) -> torch.Tensor:
     r"""
     This function creates a mask tensor for Flash Sparse Attention.
```
```diff
@@ -180,15 +178,13 @@ def create_mask(
     Args:
         attention_bias (torch.Tensor): The attention bias tensor of shape
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
-        query_len (int): The sequence length of the query.
+        type (str): The type of mask to create. Options are "topk" and "relu".
         attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
             (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
-        batch_size (int): The batch size.
-        query_len (int): The sequence length of the query.
-        key_len (int): The sequence length of the key.
         window_size (Optional[int]): The number of top elements to consider for the attention mask.
```
Suggested change:

```diff
-        window_size (Optional[int]): The number of top elements to consider for the attention mask.
+        window_size (Optional[int]): The number of top elements to consider for the attention mask. Required when type='topk', ignored when type='relu'.
```
Copilot AI · Dec 20, 2025
The documentation describes block_size as being "after top-k masking", but this parameter is actually used by both topk_mask and relu_mask functions. The description should be updated to reflect that it applies to both mask types, not just top-k. Consider: "Optional size of aggregation blocks to smooth the resulting mask along the key dimension."
Suggested change:

```diff
-        block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
+        block_size (Optional[int]): Optional size of aggregation blocks to smooth the resulting mask along the key dimension.
```
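One plausible reading of that description, sketched as a hedged example in plain Python (`smooth_mask_blocks` is a hypothetical name, not a function from this repository): within each block of `block_size` key positions, if any position survives the top-k or ReLU masking, the whole block is kept.

```python
# Hypothetical sketch of "aggregation blocks" along the key dimension:
# keep an entire block of `block_size` key positions whenever any
# position inside it was kept by the mask. Not this repository's code.
def smooth_mask_blocks(mask_row, block_size):
    out = list(mask_row)
    for start in range(0, len(mask_row), block_size):
        block = mask_row[start:start + block_size]
        if any(block):  # at least one kept position -> keep the whole block
            out[start:start + block_size] = [True] * len(block)
    return out
```

For example, with block_size=2 a row [False, True, False, False] becomes [True, True, False, False].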
Copilot AI · Dec 20, 2025
The variables batch_size and key_len are now derived from attention_bias.shape instead of being passed as parameters. While this simplifies the API, it could potentially cause issues if attention_bias has been broadcast or has unexpected dimensions. Consider adding a validation check to ensure attention_bias has the expected rank (4 dimensions) before extracting shape values, especially since the docstring indicates some dimensions can be 1 (broadcast dimensions).
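A minimal sketch of the suggested guard (`derive_mask_dims` is a hypothetical helper, not part of the PR; it is written against the shape tuple, e.g. `attention_bias.shape`, so it stays framework-agnostic):

```python
# Hypothetical sketch of the rank check suggested above. It validates that
# the bias shape has rank 4 before batch_size and key_len are extracted.
# A broadcast batch dimension of 1 still passes, which is consistent with
# the docstring's {batch_size|1} convention.
def derive_mask_dims(bias_shape):
    if len(bias_shape) != 4:
        raise ValueError(
            f"attention_bias must have 4 dimensions "
            f"({{batch_size|1}}, {{num_heads|num_kv_heads|1}}, "
            f"{{query_len|1}}, key_len), got rank {len(bias_shape)}"
        )
    batch_size, _, _, key_len = bias_shape
    return batch_size, key_len
```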
The version is being bumped from 1.2.3 to 1.2.4 (a patch version increment), but this PR introduces breaking changes to the create_mask API. According to semantic versioning, breaking changes should result in a major version bump (2.0.0), or at minimum a minor version bump (1.3.0) if the API is still in the 0.x or 1.x phase and breaking changes are acceptable. A patch version (1.2.4) should only be used for backward-compatible bug fixes. Since the function signature change breaks existing callers that use positional arguments, this warrants a more significant version change.
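To make the breakage concrete, here is a hedged sketch with a stub standing in for the new signature (the body is a placeholder, not the library's implementation): a pre-PR caller that passed attention_mask as the second positional argument now silently binds it to type.

```python
# Stub mirroring the PR's new parameter order; the body is a placeholder,
# not Flash Sparse Attention's implementation.
def create_mask_new(attention_bias, type="topk", attention_mask=None,
                    window_size=None, min_dtype=None, block_size=None):
    return (type, attention_mask)

bias, mask = object(), object()
# Old call sites wrote create_mask(bias, mask, batch_size, ...): the second
# positional argument used to be attention_mask. Against the new signature
# the same argument lands in `type`, and no TypeError is raised.
bound_type, bound_mask = create_mask_new(bias, mask)
assert bound_type is mask   # the mask silently became `type`
assert bound_mask is None   # attention_mask defaulted to None
```

Because the misuse fails silently rather than raising immediately, this is a breaking change in the semver sense.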