89 changes: 79 additions & 10 deletions flash_dmattn/utils/mask.py
@@ -17,12 +17,13 @@
import torch


def dynamic_mask(
def topk_mask(
attention_bias: torch.Tensor,
attention_mask: Optional[torch.Tensor],
window_size: int,
min_dtype: float,
block_size: Optional[int] = None,
**kwargs,
):
r"""
This function generates a dynamic mask based on the top-k values of the attention bias.
@@ -45,10 +46,10 @@ def dynamic_mask(
if int(block_size) != block_size or block_size <= 0:
raise ValueError(f"block_size must be a positive integer, got {block_size}.")
block_size = int(block_size)

attention_bias = attention_bias.detach()
attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
topk_values, topk_indices = torch.topk(
attention_bias.detach(),
attention_bias.to(torch.float),
window_size, dim=-1, largest=True, sorted=False
)
attention_mask = torch.zeros_like(
@@ -77,6 +78,61 @@
return attention_mask


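As a quick illustration of the new helper (not part of the diff; shapes, values, and the import path are assumptions), `topk_mask` keeps, for every query position, the `window_size` keys with the largest bias:

```python
import torch
from flash_dmattn.utils.mask import topk_mask  # assumed import path

batch, heads, q_len, k_len = 2, 4, 16, 128
bias = torch.randn(batch, heads, q_len, k_len)

# Keep the 32 keys with the largest bias for each query position.
mask = topk_mask(
    attention_bias=bias,
    attention_mask=None,  # no padding mask in this sketch
    window_size=32,
    min_dtype=torch.finfo(bias.dtype).min,
)
print(mask.shape)  # torch.Size([2, 4, 16, 128])
```
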
def relu_mask(
attention_bias: torch.Tensor,
attention_mask: Optional[torch.Tensor],
min_dtype: float,
block_size: Optional[int] = None,
**kwargs
):
r"""
This function generates a dynamic mask based on the ReLU of the attention bias.

Args:
attention_bias (torch.Tensor): The attention bias tensor of shape
({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
min_dtype (float): The minimum value to use for masking.
block_size (Optional[int]): Optional size of aggregation blocks to smooth the
resulting mask along the key dimension.

Returns:
attention_mask (Tensor): The attention mask tensor of shape
({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
"""
if block_size is not None:
if int(block_size) != block_size or block_size <= 0:
raise ValueError(f"block_size must be a positive integer, got {block_size}.")
block_size = int(block_size)

attention_bias = attention_bias.detach()
attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
attention_mask = attention_bias > 0

if block_size is not None and block_size > 1:
key_len = attention_mask.shape[-1]
full_len = (key_len // block_size) * block_size

if full_len:
block_view = attention_mask[..., :full_len]
block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
blocks = block_view.view(*block_shape)
block_counts = blocks.sum(dim=-1).to(torch.int32)
block_keep = (block_counts * 2) > block_size
blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))

if key_len > full_len:
tail_slice = attention_mask[..., full_len:]
tail_len = tail_slice.shape[-1]
tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
tail_keep = (tail_counts * 2) > tail_len
tail_slice.copy_(tail_keep.expand_as(tail_slice))

return attention_mask


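To make the block aggregation above concrete, here is a small worked sketch (values are illustrative and the import path is assumed, not taken from the PR): with `block_size=4`, a block of keys is kept only if more than half of its positions have a positive bias, and the ragged tail is voted on the same way.

```python
import torch
from flash_dmattn.utils.mask import relu_mask  # assumed import path

bias = torch.tensor([[[[ 1.0, -2.0,  3.0,  0.5,   # block 0: 3/4 positive -> kept
                        -1.0, -1.0,  2.0, -3.0,   # block 1: 1/4 positive -> dropped
                         0.2,  0.3]]]])           # tail of 2: 2/2 positive -> kept

mask = relu_mask(
    attention_bias=bias,
    attention_mask=None,
    min_dtype=torch.finfo(bias.dtype).min,
    block_size=4,
)
print(mask.tolist()[0][0][0])
# [True, True, True, True, False, False, False, False, True, True]
```
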

def create_mask(
attention_bias: torch.Tensor,
attention_mask: Optional[torch.Tensor],
@@ -86,6 +142,7 @@ def create_mask(
window_size: int,
min_dtype: float,
block_size: Optional[int] = None,
type: str = "topk",
) -> torch.Tensor:
r"""
This function creates a mask tensor for Flash Dynamic Mask Attention.
@@ -103,6 +160,7 @@ def create_mask(
window_size (int): The number of top elements to consider for the attention mask.
min_dtype (float): The minimum value to use for masking.
block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
type (str): The type of mask to create. Options are "topk" and "relu".

Returns:
attention (Tensor): The attention mask tensor of shape
@@ -133,12 +191,23 @@
)

# Generate dynamic mask based on attention_bias and attention_mask
attention_mask = dynamic_mask(
attention_bias,
attention_mask,
window_size,
min_dtype,
block_size=block_size,
)
if type == "topk":
attention_mask = topk_mask(
attention_bias=attention_bias,
attention_mask=attention_mask,
window_size=window_size,
min_dtype=min_dtype,
block_size=block_size,
)
elif type == "relu":
attention_mask = relu_mask(
attention_bias=attention_bias,
attention_mask=attention_mask,
window_size=window_size,

Copilot AI commented on Nov 6, 2025:

The relu_mask function does not accept a window_size parameter, but it's being passed here. This will cause an error when type == 'relu' is used. The window_size parameter should be removed from this function call, or the relu_mask function signature should be updated to accept and use it.

Suggested change: remove the line `window_size=window_size,` from this call.

min_dtype=min_dtype,
block_size=block_size,
)
else:
raise ValueError(f"Unsupported mask type: {type}. Supported types are 'topk' and 'relu'.")

return attention_mask
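
A small standalone check related to the review comment above (not part of the PR; the import path is assumed): because `relu_mask` declares `**kwargs`, the extra `window_size` keyword forwarded by the `"relu"` branch is absorbed rather than raising a `TypeError`, although it is silently ignored.

```python
import torch
from flash_dmattn.utils.mask import relu_mask  # assumed import path

bias = torch.randn(1, 1, 1, 8)
mask = relu_mask(
    attention_bias=bias,
    attention_mask=None,
    min_dtype=torch.finfo(bias.dtype).min,
    block_size=None,
    window_size=32,  # swallowed by **kwargs, not used by relu_mask
)
print(mask.shape)  # torch.Size([1, 1, 1, 8])
```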