Refactor create_mask function parameters #218
Conversation
Pull request overview
This PR refactors the `create_mask` function's parameter signature to improve clarity by reordering parameters and making several optional. The changes include removing `batch_size` and `key_len` as explicit parameters (now derived from `attention_bias`), moving `query_len` to the second position, and converting most parameters to keyword-only with defaults. The version is also bumped from 1.2.3 to 1.2.4.
Key Changes:
- Simplified `create_mask` function signature by removing `batch_size` and `key_len` parameters (now derived internally from `attention_bias`)
- Reordered parameters: `query_len` moved to 2nd position, `type` made 3rd with a default, other parameters converted to keyword-only with defaults
- Updated version number to 1.2.4
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `flash_sparse_attn/utils/mask.py` | Refactored `create_mask` function signature to remove `batch_size` and `key_len` parameters, reorder remaining parameters, and derive batch/key dimensions from the `attention_bias` shape |
| `flash_sparse_attn/__init__.py` | Version bump from 1.2.3 to 1.2.4 |
```diff
 def create_mask(
     attention_bias: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    batch_size: int,
     query_len: int,
-    key_len: int,
-    window_size: Optional[int],
-    min_dtype: Optional[float],
-    block_size: Optional[int],
     type: str = "topk",
+    attention_mask: Optional[torch.Tensor] = None,
+    window_size: Optional[int] = None,
+    min_dtype: Optional[float] = None,
+    block_size: Optional[int] = None,
 ) -> torch.Tensor:
```
The PR description states "No backward compatibility issues; the function signature has changed but maintains the same functionality." However, this refactor introduces breaking changes to the API:
- The parameter order has changed: `query_len` moved from 4th to 2nd position, and `type` moved from keyword-only to 3rd position with a default
- Required parameters `batch_size` and `key_len` have been removed entirely
- `attention_mask` moved from 2nd positional to a keyword-only parameter
This means existing code using positional arguments will break. For example, an old call like:

```python
create_mask(bias, mask, batch_sz, q_len, k_len, window, min_val, type="topk")
```

will fail with the new signature because the 2nd argument (which would be `mask`) is interpreted as `query_len` (expecting an int).
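For comparison, the same call rewritten for the new signature would look roughly like this (argument names are carried over from the example above; whether the trailing parameters are strictly keyword-only is an assumption):

```python
# Hypothetical rewrite: batch_sz and k_len are dropped (derived from bias internally),
# and attention_mask / window_size / min_dtype become keyword arguments.
create_mask(bias, q_len, "topk", attention_mask=mask, window_size=window, min_dtype=min_val)
```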
Multiple call sites in the codebase have not been updated, including those in:
- benchmarks/backward_equivalence.py (lines 101, 172, 248, 321)
- benchmarks/forward_equivalence.py (lines 101, 163, 230, 302)
- benchmarks/backward_performance.py (lines 107, 189, 268, 345)
- benchmarks/forward_performance.py (lines 108, 185, 261, 341)
- flash_sparse_attn/integrations/modeling_flash_sparse_attention_utils.py (line 649)
Consider either updating all call sites in the same PR or providing a deprecation path to maintain backward compatibility.
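One way to provide that deprecation path is a thin compatibility wrapper that detects the old calling convention and reroutes it. Below is a sketch, under the assumption that the old 2nd positional argument is a tensor (or `None`) while the new one is an int; the wrapper name `create_mask_compat` and the import path are hypothetical:

```python
import warnings

import torch

# Hypothetical import path; adjust to wherever create_mask actually lives.
from flash_sparse_attn.utils.mask import create_mask


def create_mask_compat(attention_bias: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """Accept both the old and the new positional orders, warning on the old one."""
    # Old order: (bias, attention_mask, batch_size, query_len, key_len,
    #             window_size, min_dtype, block_size); new order: (bias, query_len, type, ...).
    # Heuristic: the old 2nd positional argument is a tensor (or None), the new one is an int.
    if len(args) >= 4 and not isinstance(args[0], int):
        warnings.warn(
            "create_mask(bias, mask, batch_size, query_len, key_len, ...) is deprecated; "
            "use create_mask(bias, query_len, ...) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        attention_mask, _batch_size, query_len, _key_len, *rest = args
        # window_size, min_dtype, block_size may still be passed positionally in the old order.
        for name, value in zip(("window_size", "min_dtype", "block_size"), rest):
            kwargs.setdefault(name, value)
        return create_mask(attention_bias, query_len, attention_mask=attention_mask, **kwargs)
    return create_mask(attention_bias, *args, **kwargs)
```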
```python
    # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len)
    if attention_mask is not None and attention_mask.dim() == 2:
        batch_size, key_len = attention_bias.shape[0], attention_bias.shape[-1]
```
The variables batch_size and key_len are now derived from attention_bias.shape instead of being passed as parameters. While this simplifies the API, it could potentially cause issues if attention_bias has been broadcast or has unexpected dimensions. Consider adding a validation check to ensure attention_bias has the expected rank (4 dimensions) before extracting shape values, especially since the docstring indicates some dimensions can be 1 (broadcast dimensions).
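A minimal sketch of such a check, placed just before the shape extraction (the error message wording is an assumption):

```python
# Validate rank before deriving batch/key dimensions from attention_bias.
if attention_bias.dim() != 4:
    raise ValueError(
        "attention_bias must have 4 dimensions "
        f"(batch_size, num_heads, query_len, key_len); got shape {tuple(attention_bias.shape)}"
    )
batch_size, key_len = attention_bias.shape[0], attention_bias.shape[-1]
```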
```python
        batch_size (int): The batch size.
        query_len (int): The sequence length of the query.
        key_len (int): The sequence length of the key.
        window_size (Optional[int]): The number of top elements to consider for the attention mask.
```
The documentation for window_size describes it as "The number of top elements to consider for the attention mask" but doesn't indicate that it's required when type="topk". Since topk_mask requires a non-None window_size parameter, the documentation should clarify this requirement. Consider updating the description to something like: "The number of top elements to consider for the attention mask. Required when type='topk', ignored when type='relu'."
```diff
-        window_size (Optional[int]): The number of top elements to consider for the attention mask.
+        window_size (Optional[int]): The number of top elements to consider for the attention mask. Required when type='topk', ignored when type='relu'.
```
```python
        key_len (int): The sequence length of the key.
        window_size (Optional[int]): The number of top elements to consider for the attention mask.
        min_dtype (Optional[float]): The minimum value to use for masking.
        block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
```
The documentation describes block_size as being "after top-k masking", but this parameter is actually used by both topk_mask and relu_mask functions. The description should be updated to reflect that it applies to both mask types, not just top-k. Consider: "Optional size of aggregation blocks to smooth the resulting mask along the key dimension."
```diff
-        block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
+        block_size (Optional[int]): Optional size of aggregation blocks to smooth the resulting mask along the key dimension.
```
```diff
 from typing import Optional

-__version__ = "1.2.3"
+__version__ = "1.2.4"
```
The version is being bumped from 1.2.3 to 1.2.4 (a patch version increment), but this PR introduces breaking changes to the create_mask API. According to semantic versioning, breaking changes should result in a major version bump (2.0.0) or at minimum a minor version bump (1.3.0) if the API is still in the 0.x or 1.x phase and breaking changes are acceptable. A patch version (1.2.4) should only be used for backward-compatible bug fixes. Since the function signature change breaks existing callers that use positional arguments, this warrants a more significant version change.
| __version__ = "1.2.4" | |
| __version__ = "2.0.0" |
```python
    window_size: Optional[int] = None,
    min_dtype: Optional[float] = None,
    block_size: Optional[int] = None,
```
The `window_size` parameter is now optional with a default value of `None`, but when `type="topk"`, this parameter is required by the `topk_mask` function at line 223. If a caller uses `type="topk"` without providing `window_size`, the code will fail with a `TypeError` when trying to pass `None` to `topk_mask`, which expects an int. Consider either:
- Making `window_size` required when `type="topk"` by adding validation (see the sketch after this list), or
- Providing a sensible default value for `window_size`, or
- Making the `topk_mask` function handle `None` gracefully.
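A minimal sketch of the first option, assuming the check lives near the top of `create_mask` (the message text is an assumption):

```python
# Fail fast with a clear message instead of a TypeError deep inside topk_mask.
if type == "topk" and window_size is None:
    raise ValueError('window_size is required when type="topk"')
```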
Summary
Refactor the `create_mask` function parameters and update the version number.

Root Cause

The parameter signature of `create_mask` lacked clarity, making it difficult to understand the function's usage.

Changes

Refactored the parameters of the `create_mask` function for better clarity and consistency.

Reproduction

Tests

Compatibility

No backward compatibility issues; the function signature has changed but maintains the same functionality.

Checklist