84 changes: 42 additions & 42 deletions flash_dmattn/utils/mask.py
@@ -17,6 +17,35 @@
import torch


def block_smooth(
attention_mask: torch.Tensor,
key_len: int,
block_size: int,
):
Copilot AI Nov 6, 2025

The newly introduced block_smooth function is missing a docstring. Add documentation explaining its purpose, parameters, and return value to maintain consistency with other functions in the module like topk_mask and relu_mask.

Suggested change
):
"""
Smooths the attention mask by aggregating values in non-overlapping blocks along the key dimension.
For each block of size `block_size`, the block is set to True if more than half of its elements are True,
otherwise it is set to False. The tail (if key_len is not divisible by block_size) is handled similarly.
Args:
attention_mask (torch.Tensor): Boolean tensor of shape (..., key_len) representing the attention mask.
key_len (int): The length of the key dimension.
block_size (int): The size of each block for smoothing. Must be a positive integer.
Returns:
torch.Tensor: The smoothed attention mask tensor of the same shape as the input.
"""

if block_size <= 0:
Copilot AI Nov 6, 2025

The validation only checks if block_size <= 0, but doesn't verify that block_size is an integer. The previous code included int(block_size) != block_size checks in both topk_mask and relu_mask. Consider adding integer validation here, e.g. if not isinstance(block_size, int) or block_size <= 0:, to maintain the same input validation behavior.

Suggested change
if block_size <= 0:
if not isinstance(block_size, int) or block_size <= 0:

raise ValueError(f"block_size must be a positive integer, got {block_size}.")
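
A small illustration of what the suggested check changes, using a hypothetical standalone helper rather than the code in this PR: a value like 64.0 passes a bare block_size <= 0 test but is rejected once the isinstance check is added.

```python
def validate_block_size(block_size):
    # Hypothetical helper mirroring the suggested validation; not part of mask.py.
    if not isinstance(block_size, int) or block_size <= 0:
        raise ValueError(f"block_size must be a positive integer, got {block_size}.")

validate_block_size(64)           # passes: positive int
try:
    validate_block_size(64.0)     # 64.0 > 0, but it is not an int
except ValueError as exc:
    print(exc)                    # block_size must be a positive integer, got 64.0.
```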

if block_size > 1:
full_len = (key_len // block_size) * block_size

if full_len:
block_view = attention_mask[..., :full_len]
block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
blocks = block_view.view(*block_shape)
block_counts = blocks.sum(dim=-1).to(torch.int64)
block_keep = (block_counts * 2) > block_size
blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))

if key_len > full_len:
tail_slice = attention_mask[..., full_len:]
tail_len = tail_slice.shape[-1]
tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64)
tail_keep = (tail_counts * 2) > tail_len
tail_slice.copy_(tail_keep.expand_as(tail_slice))

return attention_mask


Comment on lines +47 to +48

Copilot AI Nov 6, 2025

There is trailing whitespace on line 47. Remove the extra whitespace after the return statement.
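
For reviewers, a minimal sketch of how the new helper behaves on a toy mask, assuming the package is importable as flash_dmattn; the shape, values, and block size below are illustrative only. The input is cloned because block_smooth writes its result back through views of the tensor it receives.

```python
import torch

from flash_dmattn.utils.mask import block_smooth

# Toy mask with key_len = 6 and block_size = 4.
# Block [True, True, False, True] holds a 3/4 majority -> the whole block becomes True.
# Tail  [False, True] has 1/2 True, not a strict majority -> the tail becomes False.
mask = torch.tensor([[True, True, False, True, False, True]])
smoothed = block_smooth(attention_mask=mask.clone(), key_len=mask.shape[-1], block_size=4)
print(smoothed)  # tensor([[ True,  True,  True,  True, False, False]])
```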
def topk_mask(
attention_bias: torch.Tensor,
attention_mask: Optional[torch.Tensor],
@@ -42,14 +71,11 @@ def topk_mask(
attention_mask (Tensor): The attention mask tensor of shape
({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
"""
if block_size is not None:
if int(block_size) != block_size or block_size <= 0:
raise ValueError(f"block_size must be a positive integer, got {block_size}.")
block_size = int(block_size)

attention_bias = attention_bias.detach()
attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
topk_values, topk_indices = torch.topk(
attention_bias.to(torch.float),
attention_bias,
Copilot AI Nov 6, 2025

The .to(torch.float) conversion was removed from the torch.topk call. While this may be intentional to preserve the original dtype, it changes the existing behavior. If attention_bias is not already torch.float, this could affect numerical precision in the topk operation. Verify this is the intended behavior or document why the dtype conversion was removed.

Suggested change
attention_bias,
attention_bias.to(torch.float),
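
One hedged way to check whether dropping the upcast changes which indices are selected: run both forms side by side on the dtypes used in practice. The shapes, k, and dtype below are illustrative, and this assumes the local PyTorch build supports topk for that dtype on the chosen device.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
bias = torch.randn(1, 1, 1, 16, dtype=torch.bfloat16, device=device)

# New call path (no upcast) vs. the previous behavior (upcast to float32).
idx_new = torch.topk(bias, 8, dim=-1, largest=True, sorted=False).indices
idx_old = torch.topk(bias.to(torch.float), 8, dim=-1, largest=True, sorted=False).indices

# sorted=False gives no ordering guarantee, so compare the selected index sets.
print(set(idx_new.flatten().tolist()) == set(idx_old.flatten().tolist()))
```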
window_size, dim=-1, largest=True, sorted=False
)
attention_mask = torch.zeros_like(
@@ -58,22 +84,11 @@

if block_size is not None and block_size > 1:
key_len = attention_mask.shape[-1]
full_len = (key_len // block_size) * block_size

if full_len:
block_view = attention_mask[..., :full_len]
block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
blocks = block_view.view(*block_shape)
block_counts = blocks.sum(dim=-1).to(torch.int32)
block_keep = (block_counts * 2) > block_size
blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))

if key_len > full_len:
tail_slice = attention_mask[..., full_len:]
tail_len = tail_slice.shape[-1]
tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
tail_keep = (tail_counts * 2) > tail_len
tail_slice.copy_(tail_keep.expand_as(tail_slice))
attention_mask = block_smooth(
attention_mask=attention_mask,
key_len=key_len,
block_size=block_size
)

return attention_mask

@@ -101,33 +116,18 @@ def relu_mask(
attention_mask (Tensor): The attention mask tensor of shape
({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
"""
if block_size is not None:
if int(block_size) != block_size or block_size <= 0:
raise ValueError(f"block_size must be a positive integer, got {block_size}.")
block_size = int(block_size)


attention_bias = attention_bias.detach()
attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
attention_mask = attention_bias > 0

if block_size is not None and block_size > 1:
key_len = attention_mask.shape[-1]
full_len = (key_len // block_size) * block_size

if full_len:
block_view = attention_mask[..., :full_len]
block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
blocks = block_view.view(*block_shape)
block_counts = blocks.sum(dim=-1).to(torch.int32)
block_keep = (block_counts * 2) > block_size
blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))

if key_len > full_len:
tail_slice = attention_mask[..., full_len:]
tail_len = tail_slice.shape[-1]
tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
tail_keep = (tail_counts * 2) > tail_len
tail_slice.copy_(tail_keep.expand_as(tail_slice))
attention_mask = block_smooth(
attention_mask=attention_mask,
key_len=key_len,
block_size=block_size
)

return attention_mask
