Conversation

@LoserCheems
Collaborator

Summary

  • Improve clarity and consistency of the create_mask function parameters and update the version number.

Root Cause

  • The original parameter structure of create_mask lacked clarity, making it difficult to understand the function's usage.

Changes

  • Refactored parameters of the create_mask function for better clarity and consistency.
  • Updated the version number to 1.2.4.

Reproduction

  • Not applicable as this is a refactor and version bump.

Tests

  • No new tests added; existing tests remain valid.

Compatibility

  • No backward compatibility issues; the function signature has changed but maintains the same functionality.

Checklist

  • Linked issue provided
  • Adds or updates tests
  • Updates docs if needed
  • No perf regressions

Copilot AI review requested due to automatic review settings December 20, 2025 14:01
@LoserCheems LoserCheems merged commit bd824a7 into main Dec 20, 2025
7 of 8 checks passed
Contributor

Copilot AI left a comment

Pull request overview

This PR refactors the create_mask function's parameter signature to improve clarity by reordering parameters and making several optional. The changes include removing batch_size and key_len as explicit parameters (now derived from attention_bias), moving query_len to the second position, and converting most parameters to keyword-only with defaults. The version is also bumped from 1.2.3 to 1.2.4.

Key Changes:

  • Simplified create_mask function signature by removing batch_size and key_len parameters (now derived internally from attention_bias)
  • Reordered parameters: query_len moved to the 2nd position, type made the 3rd parameter with a default, and the remaining parameters converted to keyword-only with defaults (see the call sketch after this list)
  • Updated version number to 1.2.4
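To make the reordering concrete, here is a sketch of what a call site might look like after this change. The import path, tensor shapes, and argument values are assumptions for illustration and are not taken from the repository.

```python
import torch

# Assumed import path, based on the changed file flash_sparse_attn/utils/mask.py.
from flash_sparse_attn.utils.mask import create_mask

# Illustrative tensors with a (batch, heads, query_len, key_len) bias layout.
attention_bias = torch.randn(2, 8, 128, 128)
attention_mask = torch.ones(2, 128)  # (batch_size, key_len)

mask = create_mask(
    attention_bias,                 # batch_size and key_len are now derived from this tensor
    128,                            # query_len, now the 2nd positional parameter
    "topk",                         # type, now the 3rd parameter with a default of "topk"
    attention_mask=attention_mask,  # remaining parameters are keyword arguments with defaults
    window_size=64,
    min_dtype=torch.finfo(torch.float32).min,
    block_size=None,
)
```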

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.

  • flash_sparse_attn/utils/mask.py: Refactored the create_mask function signature to remove the batch_size and key_len parameters, reorder the remaining parameters, and derive the batch and key dimensions from the shape of attention_bias
  • flash_sparse_attn/__init__.py: Version bump from 1.2.3 to 1.2.4

Comment on lines 164 to 172
  def create_mask(
      attention_bias: torch.Tensor,
-     attention_mask: Optional[torch.Tensor],
-     batch_size: int,
      query_len: int,
-     key_len: int,
-     window_size: Optional[int],
-     min_dtype: Optional[float],
-     block_size: Optional[int],
      type: str = "topk",
+     attention_mask: Optional[torch.Tensor] = None,
+     window_size: Optional[int] = None,
+     min_dtype: Optional[float] = None,
+     block_size: Optional[int] = None,
  ) -> torch.Tensor:

Copilot AI Dec 20, 2025

The PR description states "No backward compatibility issues; the function signature has changed but maintains the same functionality." However, this refactor introduces breaking changes to the API:

  1. The parameter order has changed: query_len moved from 4th to 2nd position, and type moved from keyword-only to 3rd position with a default
  2. Required parameters batch_size and key_len have been removed entirely
  3. attention_mask moved from 2nd positional to a keyword-only parameter

This means existing code using positional arguments will break. For example, old calls like:
create_mask(bias, mask, batch_sz, q_len, k_len, window, min_val, type="topk")

This call will fail with the new signature because the second argument (mask) would be interpreted as query_len, which expects an int.

Multiple call sites in the codebase have not been updated, including those in:

  • benchmarks/backward_equivalence.py (lines 101, 172, 248, 321)
  • benchmarks/forward_equivalence.py (lines 101, 163, 230, 302)
  • benchmarks/backward_performance.py (lines 107, 189, 268, 345)
  • benchmarks/forward_performance.py (lines 108, 185, 261, 341)
  • flash_sparse_attn/integrations/modeling_flash_sparse_attention_utils.py (line 649)

Consider either updating all call sites in the same PR or providing a deprecation path to maintain backward compatibility.
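A minimal sketch of what such a deprecation path could look like, assuming the new signature shown in the diff above; create_mask_legacy is a hypothetical helper, not part of this PR, and the old parameter layout is approximated from the removed lines.

```python
import warnings
from typing import Optional

import torch

# Assumed import path for the refactored function.
from flash_sparse_attn.utils.mask import create_mask


def create_mask_legacy(
    attention_bias: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    batch_size: int,
    query_len: int,
    key_len: int,
    window_size: Optional[int],
    min_dtype: Optional[float],
    block_size: Optional[int] = None,
    type: str = "topk",
) -> torch.Tensor:
    """Accepts the pre-1.2.4 positional layout and forwards to the new signature."""
    warnings.warn(
        "create_mask_legacy mirrors the old create_mask signature; "
        "please migrate call sites to the new signature.",
        DeprecationWarning,
        stacklevel=2,
    )
    # batch_size and key_len are ignored here because the refactored
    # create_mask derives them from attention_bias internally.
    return create_mask(
        attention_bias,
        query_len,
        type,
        attention_mask=attention_mask,
        window_size=window_size,
        min_dtype=min_dtype,
        block_size=block_size,
    )
```

Existing positional callers could then be pointed at the shim while call sites are migrated.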

# If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len)
if attention_mask is not None and attention_mask.dim() == 2:
    batch_size, key_len = attention_bias.shape[0], attention_bias.shape[-1]

Copilot AI Dec 20, 2025

The variables batch_size and key_len are now derived from attention_bias.shape instead of being passed as parameters. While this simplifies the API, it could potentially cause issues if attention_bias has been broadcast or has unexpected dimensions. Consider adding a validation check to ensure attention_bias has the expected rank (4 dimensions) before extracting shape values, especially since the docstring indicates some dimensions can be 1 (broadcast dimensions).
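For instance, a minimal sketch of such a validation, placed just before the shape values are read; the exact wording and placement are assumptions.

```python
# Reject inputs that do not carry the expected (batch, heads, query_len, key_len) layout.
if attention_bias.dim() != 4:
    raise ValueError(
        "attention_bias is expected to be 4-dimensional "
        f"(batch_size, num_heads, query_len, key_len), got {attention_bias.dim()} dimensions"
    )
batch_size, key_len = attention_bias.shape[0], attention_bias.shape[-1]
```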

batch_size (int): The batch size.
query_len (int): The sequence length of the query.
key_len (int): The sequence length of the key.
window_size (Optional[int]): The number of top elements to consider for the attention mask.

Copilot AI Dec 20, 2025

The documentation for window_size describes it as "The number of top elements to consider for the attention mask" but doesn't indicate that it's required when type="topk". Since topk_mask requires a non-None window_size parameter, the documentation should clarify this requirement. Consider updating the description to something like: "The number of top elements to consider for the attention mask. Required when type='topk', ignored when type='relu'."

Suggested change
window_size (Optional[int]): The number of top elements to consider for the attention mask.
window_size (Optional[int]): The number of top elements to consider for the attention mask. Required when type='topk', ignored when type='relu'.

key_len (int): The sequence length of the key.
window_size (Optional[int]): The number of top elements to consider for the attention mask.
min_dtype (Optional[float]): The minimum value to use for masking.
block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.

Copilot AI Dec 20, 2025

The documentation describes block_size as being "after top-k masking", but this parameter is actually used by both topk_mask and relu_mask functions. The description should be updated to reflect that it applies to both mask types, not just top-k. Consider: "Optional size of aggregation blocks to smooth the resulting mask along the key dimension."

Suggested change
block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
block_size (Optional[int]): Optional size of aggregation blocks to smooth the resulting mask along the key dimension.

from typing import Optional

__version__ = "1.2.3"
__version__ = "1.2.4"

Copilot AI Dec 20, 2025

The version is being bumped from 1.2.3 to 1.2.4 (a patch version increment), but this PR introduces breaking changes to the create_mask API. According to semantic versioning, breaking changes should result in a major version bump (2.0.0) or at minimum a minor version bump (1.3.0) if the API is still in the 0.x or 1.x phase and breaking changes are acceptable. A patch version (1.2.4) should only be used for backward-compatible bug fixes. Since the function signature change breaks existing callers that use positional arguments, this warrants a more significant version change.

Suggested change
__version__ = "1.2.4"
__version__ = "2.0.0"

Comment on lines +169 to +171
    window_size: Optional[int] = None,
    min_dtype: Optional[float] = None,
    block_size: Optional[int] = None,

Copilot AI Dec 20, 2025

The window_size parameter is now optional with a default value of None, but when type="topk", this parameter is required by the topk_mask function at line 223. If a caller uses type="topk" without providing window_size, the code will fail with a TypeError when trying to pass None to topk_mask which expects an int. Consider either:

  1. Making window_size required when type="topk" by adding validation (see the sketch below), or
  2. Providing a sensible default value for window_size, or
  3. Making the topk_mask function handle None gracefully.
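As a sketch of option 1, using the parameter names from the snippet above (the placement inside create_mask is an assumption), the check could fail fast before topk_mask is reached:

```python
# Guard against a missing window_size before dispatching to topk_mask.
if type == "topk" and window_size is None:
    raise ValueError('window_size must be provided when type="topk"')
```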
