Merged
2 changes: 1 addition & 1 deletion flash_sparse_attn/__init__.py
@@ -2,7 +2,7 @@

from typing import Optional

-__version__ = "1.2.3"
+__version__ = "1.2.4"
Copilot AI Dec 20, 2025
The version is being bumped from 1.2.3 to 1.2.4 (a patch increment), but this PR introduces breaking changes to the create_mask API. Under semantic versioning, breaking changes call for a major version bump (2.0.0), or at minimum a minor bump (1.3.0) if breaking changes are still considered acceptable in the 1.x line. A patch version (1.2.4) should be reserved for backward-compatible bug fixes. Since the signature change breaks existing callers that use positional arguments, it warrants a more significant version change.

Suggested change:
-__version__ = "1.2.4"
+__version__ = "2.0.0"


# Import CUDA functions when available
17 changes: 7 additions & 10 deletions flash_sparse_attn/utils/mask.py
@@ -163,14 +163,12 @@ def relu_mask(

 def create_mask(
     attention_bias: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    batch_size: int,
     query_len: int,
-    key_len: int,
-    window_size: Optional[int],
-    min_dtype: Optional[float],
-    block_size: Optional[int],
     type: str = "topk",
+    attention_mask: Optional[torch.Tensor] = None,
+    window_size: Optional[int] = None,
+    min_dtype: Optional[float] = None,
+    block_size: Optional[int] = None,
Comment on lines +169 to +171
Copilot AI Dec 20, 2025
The window_size parameter is now optional with a default of None, but when type="topk" it is required by the topk_mask function at line 223. If a caller uses type="topk" without providing window_size, the call will fail with a TypeError when None is passed to topk_mask, which expects an int. Consider one of the following:

  1. Making window_size required when type="topk" by adding validation, or
  2. Providing a sensible default value for window_size, or
  3. Making the topk_mask function handle None gracefully.

) -> torch.Tensor:
Comment on lines 164 to 172
Copilot AI Dec 20, 2025
The PR description states "No backward compatibility issues; the function signature has changed but maintains the same functionality." However, this refactor introduces breaking changes to the API:

  1. The parameter order has changed: query_len moved from 4th to 2nd position, and type moved from keyword-only to 3rd position with a default
  2. Required parameters batch_size and key_len have been removed entirely
  3. attention_mask moved from 2nd positional to a keyword-only parameter

This means existing code using positional arguments will break. For example, old calls like:
create_mask(bias, mask, batch_sz, q_len, k_len, window, min_val, type="topk")

Will fail with the new signature because the 2nd argument (which would be mask) will be interpreted as query_len (expecting an int).

Multiple call sites in the codebase have not been updated, including those in:

  • benchmarks/backward_equivalence.py (lines 101, 172, 248, 321)
  • benchmarks/forward_equivalence.py (lines 101, 163, 230, 302)
  • benchmarks/backward_performance.py (lines 107, 189, 268, 345)
  • benchmarks/forward_performance.py (lines 108, 185, 261, 341)
  • flash_sparse_attn/integrations/modeling_flash_sparse_attention_utils.py (line 649)

Consider either updating all call sites in the same PR or providing a deprecation path to maintain backward compatibility.
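One possible deprecation path is sketched below. This is an assumption, not the PR's actual code: the removed keyword arguments are swallowed with a DeprecationWarning so old keyword-based call sites keep working for one release, and the mask construction itself is stubbed out. Note that purely positional callers would still need updating, since a shim cannot reorder positional arguments safely:

```python
import warnings

def create_mask_compat(attention_bias, query_len, type="topk",
                       attention_mask=None, window_size=None,
                       min_dtype=None, block_size=None, **legacy):
    # Accept and discard the removed parameters with a warning, so old
    # keyword-based callers are not broken immediately.
    for name in ("batch_size", "key_len"):
        if name in legacy:
            legacy.pop(name)
            warnings.warn(
                f"create_mask() no longer accepts {name!r}; it is derived "
                "from attention_bias.shape",
                DeprecationWarning, stacklevel=2,
            )
    if legacy:
        raise TypeError(f"unexpected keyword arguments: {sorted(legacy)}")
    return (query_len, type)  # placeholder for the real mask logic
```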

r"""
This function creates a mask tensor for Flash Sparse Attention.
@@ -180,15 +178,13 @@
     Args:
         attention_bias (torch.Tensor): The attention bias tensor of shape
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        query_len (int): The sequence length of the query.
+        type (str): The type of mask to create. Options are "topk" and "relu".
         attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
             (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
-        batch_size (int): The batch size.
-        query_len (int): The sequence length of the query.
-        key_len (int): The sequence length of the key.
         window_size (Optional[int]): The number of top elements to consider for the attention mask.
Copilot AI Dec 20, 2025
The documentation for window_size describes it as "The number of top elements to consider for the attention mask" but doesn't indicate that it's required when type="topk". Since topk_mask requires a non-None window_size parameter, the documentation should clarify this requirement. Consider updating the description to something like: "The number of top elements to consider for the attention mask. Required when type='topk', ignored when type='relu'."

Suggested change:
-        window_size (Optional[int]): The number of top elements to consider for the attention mask.
+        window_size (Optional[int]): The number of top elements to consider for the attention mask. Required when type='topk', ignored when type='relu'.

min_dtype (Optional[float]): The minimum value to use for masking.
block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
Copilot AI Dec 20, 2025
The documentation describes block_size as being "after top-k masking", but this parameter is actually used by both topk_mask and relu_mask functions. The description should be updated to reflect that it applies to both mask types, not just top-k. Consider: "Optional size of aggregation blocks to smooth the resulting mask along the key dimension."

Suggested change:
-        block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
+        block_size (Optional[int]): Optional size of aggregation blocks to smooth the resulting mask along the key dimension.

-        type (str): The type of mask to create. Options are "topk" and "relu".

Returns:
attention (Tensor): The attention mask tensor of shape
@@ -200,6 +196,7 @@

# If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len)
if attention_mask is not None and attention_mask.dim() == 2:
+    batch_size, key_len = attention_bias.shape[0], attention_bias.shape[-1]
Copilot AI Dec 20, 2025

The variables batch_size and key_len are now derived from attention_bias.shape instead of being passed as parameters. While this simplifies the API, it could potentially cause issues if attention_bias has been broadcast or has unexpected dimensions. Consider adding a validation check to ensure attention_bias has the expected rank (4 dimensions) before extracting shape values, especially since the docstring indicates some dimensions can be 1 (broadcast dimensions).

if attention_mask.shape[-1] == key_len:
attention_mask = attention_mask.view(batch_size, 1, 1, key_len)
elif attention_mask.shape[-1] == query_len: