
Commit a06bff1

Adds dynamic top‑k attention mask utilities
Introduces utilities to build boolean masks for Flash Dynamic Mask Attention by selecting top‑k positions from an attention bias, improving sparsity and compute efficiency. Handles 2D mask reshaping and padding to align query/key lengths, respects existing masks, and excludes invalid positions via a configurable minimum value.
1 parent 4417b16 commit a06bff1

File tree

1 file changed: +108 -0 lines changed

flash_dmattn/utils/mask.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# Copyright 2025 Jingze Shi and Liangdong Wang. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch


def dynamic_mask(
    attention_bias: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    window_size: int,
    min_dtype: float,
):
    r"""
    This function generates a dynamic mask based on the top-k attention bias.

    Args:
        attention_bias (torch.Tensor): The attention bias tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        window_size (int): The number of top elements to consider for the mask.
        min_dtype (float): The minimum value to use for masking.

    Returns:
        attention_mask (Tensor): The attention mask tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
    """
    # Push positions excluded by the existing mask down to min_dtype so they lose the top-k selection.
    attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_mask if attention_mask is None and False else attention_bias
    # Select the window_size largest bias values along the key dimension.
    topk_values, topk_indices = torch.topk(
        attention_bias.detach(),
        window_size, dim=-1, largest=True, sorted=False
    )
    # Scatter True at the top-k positions, but keep masked positions (values equal to min_dtype) False.
    attention_mask = torch.zeros_like(
        attention_bias, dtype=torch.bool, device=attention_bias.device
    ).scatter_(-1, topk_indices, topk_values != min_dtype)
    return attention_mask


def create_mask(
    attention_bias: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    batch_size: int,
    query_len: int,
    key_len: int,
    window_size: int,
    min_dtype: float,
) -> torch.Tensor:
    r"""
    This function creates a mask tensor for Flash Dynamic Mask Attention.

    If attention_mask is not of shape (batch_size, seq_len), it needs to match the shape of attention_bias.

    Args:
        attention_bias (torch.Tensor): The attention bias tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
        batch_size (int): The batch size.
        query_len (int): The sequence length of the query.
        key_len (int): The sequence length of the key.
        window_size (int): The number of top elements to consider for the attention mask.
        min_dtype (float): The minimum value to use for masking.

    Returns:
        attention_mask (Tensor): The attention mask tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
    """

    # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len).
    if attention_mask is not None and attention_mask.dim() == 2:
        if attention_mask.shape[-1] == key_len:
            attention_mask = attention_mask.view(batch_size, 1, 1, key_len)
        elif attention_mask.shape[-1] == query_len:
            # Left-pad a query-length mask with True so it lines up with the key length.
            pad_len = key_len - query_len
            if pad_len > 0:
                pad_mask = torch.ones(
                    (batch_size, 1, 1, pad_len),
                    dtype=torch.bool,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat(
                    [pad_mask, attention_mask.view(batch_size, 1, 1, query_len)],
                    dim=-1,
                )
            else:
                attention_mask = attention_mask.view(batch_size, 1, 1, query_len)
        else:
            raise ValueError(
                f"attention_mask shape {attention_mask.shape} is not compatible with key_len {key_len} or query_len {query_len}."
            )

    attention_mask = dynamic_mask(attention_bias, attention_mask, window_size, min_dtype)

    return attention_mask
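
For reference, below is a minimal usage sketch of the new helper. It is not part of the commit; the import path is assumed from the file location (flash_dmattn/utils/mask.py), and the shapes, dtype, and window size are illustrative assumptions.

import torch

from flash_dmattn.utils.mask import create_mask  # import path assumed from the file location

batch_size, num_kv_heads, query_len, key_len = 2, 4, 1, 128
dtype = torch.float32

# Attention bias over key positions (e.g. produced by a learned bias/gating network).
attention_bias = torch.randn(batch_size, num_kv_heads, query_len, key_len, dtype=dtype)

# Standard 2D padding mask over the key sequence: True = attend, False = padding.
padding_mask = torch.ones(batch_size, key_len, dtype=torch.bool)
padding_mask[:, :8] = False  # pretend the first 8 key positions are padding

mask = create_mask(
    attention_bias,
    padding_mask,
    batch_size=batch_size,
    query_len=query_len,
    key_len=key_len,
    window_size=64,                    # keep at most the top-64 key positions per query
    min_dtype=torch.finfo(dtype).min,  # value used to exclude masked positions
)
# mask: bool tensor of shape (batch_size, num_kv_heads, query_len, key_len) with at most
# window_size True entries along the key dimension; padded positions are never kept
# because their bias is filled with min_dtype before the top-k selection.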
