
Commit 9bd0fd6

Adds block-wise smoothing to attention mask
Introduces an optional block_size to aggregate top-k selections along the key dimension using a majority vote, reducing fragmentation and encouraging locality in the dynamic mask. Validates block_size as a positive integer, handles remainder tails, and forwards the parameter through mask creation. Updates docs accordingly. Preserves previous behavior when unset and uses in-place ops for efficiency.
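As a concrete illustration of the majority vote described above (not the repository code itself, which operates in place on the full mask and also handles a remainder tail), the toy sketch below applies the same rule to an 8-key boolean mask with a hypothetical block_size of 4: a block of keys survives only if strictly more than half of its positions were selected by the top-k step.

import torch

# Toy boolean mask over 8 key positions (as if produced by a top-k selection),
# smoothed with block_size = 4 using a strict-majority vote per block.
mask = torch.tensor([[True, False, True, True, False, True, False, False]])
block_size = 4

key_len = mask.shape[-1]
full_len = (key_len // block_size) * block_size  # largest multiple of block_size

# Reshape the covered prefix into (…, num_blocks, block_size) and vote per block.
blocks = mask[..., :full_len].view(*mask.shape[:-1], full_len // block_size, block_size)
block_keep = blocks.sum(dim=-1) * 2 > block_size          # strict majority per block
smoothed = block_keep.unsqueeze(-1).expand_as(blocks).reshape(*mask.shape[:-1], full_len)

print(smoothed)
# tensor([[ True,  True,  True,  True, False, False, False, False]])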
1 parent 424b733 · commit 9bd0fd6

File tree

1 file changed (+37, -1 lines)


flash_dmattn/utils/mask.py

Lines changed: 37 additions & 1 deletion
@@ -22,6 +22,7 @@ def dynamic_mask(
     attention_mask: Optional[torch.Tensor],
     window_size: int,
     min_dtype: float,
+    block_size: Optional[int] = None,
 ):
     r"""
     This function generates a dynamic mask based on the top-k attention bias.
@@ -33,11 +34,18 @@ def dynamic_mask(
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
         window_size (int): The number of top elements to consider for the mask.
         min_dtype (float): The minimum value to use for masking.
+        block_size (Optional[int]): Optional size of aggregation blocks to smooth the
+            resulting mask along the key dimension.
 
     Returns:
         attention_mask (Tensor): The attention mask tensor of shape
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
     """
+    if block_size is not None:
+        if int(block_size) != block_size or block_size <= 0:
+            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
+        block_size = int(block_size)
+
     attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
     topk_values, topk_indices = torch.topk(
         attention_bias.detach(),
@@ -46,6 +54,26 @@ def dynamic_mask(
     attention_mask = torch.zeros_like(
         attention_bias, dtype=torch.bool, device=attention_bias.device
     ).scatter_(-1, topk_indices, topk_values != min_dtype)
+
+    if block_size is not None and block_size > 1:
+        key_len = attention_mask.shape[-1]
+        full_len = (key_len // block_size) * block_size
+
+        if full_len:
+            block_view = attention_mask[..., :full_len]
+            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
+            blocks = block_view.view(*block_shape)
+            block_counts = blocks.sum(dim=-1).to(torch.int32)
+            block_keep = (block_counts * 2) > block_size
+            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
+
+        if key_len > full_len:
+            tail_slice = attention_mask[..., full_len:]
+            tail_len = tail_slice.shape[-1]
+            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
+            tail_keep = (tail_counts * 2) > tail_len
+            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+
     return attention_mask
 
 
@@ -57,6 +85,7 @@ def create_mask(
     key_len: int,
     window_size: int,
     min_dtype: float,
+    block_size: Optional[int] = None,
 ) -> torch.Tensor:
     r"""
     This function creates a mask tensor for Flash Dynamic Mask Attention.
@@ -73,6 +102,7 @@ def create_mask(
         key_len (int): The sequence length of the key.
         window_size (int): The number of top elements to consider for the attention mask.
         min_dtype (float): The minimum value to use for masking.
+        block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
 
     Returns:
         attention (Tensor): The attention mask tensor of shape
@@ -103,6 +133,12 @@ def create_mask(
     )
 
     # Generate dynamic mask based on attention_bias and attention_mask
-    attention_mask = dynamic_mask(attention_bias, attention_mask, window_size, min_dtype)
+    attention_mask = dynamic_mask(
+        attention_bias,
+        attention_mask,
+        window_size,
+        min_dtype,
+        block_size=block_size,
+    )
 
     return attention_mask
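A minimal usage sketch of the extended dynamic_mask is shown below. The import path is inferred from the file location (flash_dmattn/utils/mask.py), and the tensor shapes, window_size, and block_size values are illustrative assumptions rather than values taken from the repository's documentation or tests.

import torch
from flash_dmattn.utils.mask import dynamic_mask  # import path assumed from the file location

batch, heads, query_len, key_len = 1, 4, 128, 1024

# Illustrative inputs: a learned attention bias and an optional boolean key-validity mask.
attention_bias = torch.randn(batch, heads, query_len, key_len)
attention_mask = torch.ones(batch, heads, query_len, key_len, dtype=torch.bool)

mask = dynamic_mask(
    attention_bias,
    attention_mask,
    window_size=64,                               # keep the top-64 keys per query
    min_dtype=torch.finfo(attention_bias.dtype).min,
    block_size=16,                                # new: smooth the selection in blocks of 16 keys
)
print(mask.shape, mask.dtype)                     # expected: same shape as the bias, torch.bool

Leaving block_size unset (or setting it to 1) should reproduce the previous per-key top-k behavior, per the commit description.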
