-
Notifications
You must be signed in to change notification settings - Fork 53
Expand file tree
/
Copy pathmask.py
More file actions
213 lines (185 loc) · 8.76 KB
/
mask.py
File metadata and controls
213 lines (185 loc) · 8.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# Copyright 2025 Jingze Shi and Liangdong Wang. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
def topk_mask(
    attention_bias: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    window_size: int,
    min_dtype: float,
    block_size: Optional[int] = None,
    **kwargs,
):
    r"""
    This function generates a dynamic mask based on the top-k attention bias.

    The bias is detached (no gradients flow through mask construction), positions
    disallowed by ``attention_mask`` are filled with ``min_dtype`` so they lose
    the top-k competition, and the ``window_size`` largest entries along the key
    dimension are kept. Optionally the result is smoothed block-wise along the
    key dimension: a block is kept only if a strict majority of its positions
    were selected.

    Args:
        attention_bias (torch.Tensor): The attention bias tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        window_size (int): The number of top elements to consider for the mask.
            Values larger than ``key_len`` are clamped to ``key_len``.
        min_dtype (float): The minimum value to use for masking.
        block_size (Optional[int]): Optional size of aggregation blocks to smooth the
            resulting mask along the key dimension.

    Returns:
        attention_mask (Tensor): The attention mask tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).

    Raises:
        ValueError: If ``block_size`` is not a positive integer.
    """
    if block_size is not None:
        if int(block_size) != block_size or block_size <= 0:
            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
        block_size = int(block_size)
    # Detach: the mask is a hard selection, no gradient should flow through it.
    attention_bias = attention_bias.detach()
    if attention_mask is not None:
        attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype)
    # Clamp k so torch.topk does not raise when there are fewer keys than
    # window_size (e.g. short sequences at the start of generation).
    k = min(window_size, attention_bias.shape[-1])
    topk_values, topk_indices = torch.topk(
        attention_bias.to(torch.float),
        k, dim=-1, largest=True, sorted=False
    )
    # A selected position stays True only if its value is not min_dtype, i.e.
    # it was not chosen merely because every candidate was already masked out.
    attention_mask = torch.zeros_like(
        attention_bias, dtype=torch.bool, device=attention_bias.device
    ).scatter_(-1, topk_indices, topk_values != min_dtype)
    if block_size is not None and block_size > 1:
        key_len = attention_mask.shape[-1]
        full_len = (key_len // block_size) * block_size
        if full_len:
            # Majority vote per full block: keep the whole block iff strictly
            # more than half of its positions were selected.
            block_view = attention_mask[..., :full_len]
            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
            blocks = block_view.view(*block_shape)
            block_counts = blocks.sum(dim=-1).to(torch.int32)
            block_keep = (block_counts * 2) > block_size
            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
        if key_len > full_len:
            # Trailing partial block votes over its own (shorter) length.
            tail_slice = attention_mask[..., full_len:]
            tail_len = tail_slice.shape[-1]
            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
            tail_keep = (tail_counts * 2) > tail_len
            tail_slice.copy_(tail_keep.expand_as(tail_slice))
    return attention_mask
def relu_mask(
    attention_bias: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    min_dtype: float,
    block_size: Optional[int] = None,
    **kwargs
):
    r"""
    This function generates a dynamic mask based on the ReLU of attention bias.

    Positions with a strictly positive (detached) bias are kept; positions
    disallowed by ``attention_mask`` are first pushed to ``min_dtype`` so they
    can never pass the threshold. Optionally the resulting mask is smoothed
    along the key dimension via a per-block majority vote.

    Args:
        attention_bias (torch.Tensor): The attention bias tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        min_dtype (float): The minimum value to use for masking.
        block_size (Optional[int]): Optional size of aggregation blocks to smooth the
            resulting mask along the key dimension.

    Returns:
        attention_mask (Tensor): The attention mask tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).

    Raises:
        ValueError: If ``block_size`` is not a positive integer.
    """
    if block_size is not None:
        if int(block_size) != block_size or block_size <= 0:
            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
        block_size = int(block_size)
    bias = attention_bias.detach()
    if attention_mask is not None:
        bias = bias.masked_fill(~attention_mask, min_dtype)
    # ReLU-style thresholding: keep exactly the strictly positive biases.
    keep = bias > 0
    if block_size is not None and block_size > 1:
        key_len = keep.shape[-1]
        num_full_blocks = key_len // block_size
        full_len = num_full_blocks * block_size
        if full_len:
            # Keep a full block iff strictly more than half its keys are kept.
            head = keep[..., :full_len]
            grouped = head.view(*head.shape[:-1], num_full_blocks, block_size)
            votes = grouped.sum(dim=-1).to(torch.int32)
            majority = (votes * 2) > block_size
            grouped.copy_(majority.unsqueeze(-1).expand_as(grouped))
        if key_len > full_len:
            # The leftover partial block votes over its own shorter length.
            tail = keep[..., full_len:]
            tail_len = tail.shape[-1]
            tail_votes = tail.sum(dim=-1, keepdim=True).to(torch.int32)
            tail.copy_(((tail_votes * 2) > tail_len).expand_as(tail))
    return keep
def create_mask(
    attention_bias: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    batch_size: int,
    query_len: int,
    key_len: int,
    window_size: int,
    min_dtype: float,
    block_size: Optional[int] = None,
    type: str = "topk",
) -> torch.Tensor:
    r"""
    This function creates a mask tensor for Flash Dynamic Mask Attention.

    A 2-D ``attention_mask`` of shape (batch_size, seq_len) is first promoted to
    the 4-D broadcastable layout (batch_size, 1, 1, key_len); when it covers only
    the query positions, the missing key prefix is padded with ``True`` (visible).
    The promoted mask and bias are then dispatched to the selected mask builder.

    Args:
        attention_bias (torch.Tensor): The attention bias tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        batch_size (int): The batch size.
        query_len (int): The sequence length of the query.
        key_len (int): The sequence length of the key.
        window_size (int): The number of top elements to consider for the attention mask.
        min_dtype (float): The minimum value to use for masking.
        block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
        type (str): The type of mask to create. Options are "topk" and "relu".

    Returns:
        attention (Tensor): The attention mask tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).

    Raises:
        ValueError: If a 2-D ``attention_mask`` matches neither ``key_len`` nor
            ``query_len``, or if ``type`` is not a supported mask type.
    """
    # Promote a (batch_size, seq_len) padding mask to the 4-D broadcast layout.
    if attention_mask is not None and attention_mask.dim() == 2:
        seq_len = attention_mask.shape[-1]
        if seq_len == key_len:
            attention_mask = attention_mask.view(batch_size, 1, 1, key_len)
        elif seq_len == query_len:
            promoted = attention_mask.view(batch_size, 1, 1, query_len)
            pad_len = key_len - query_len
            if pad_len > 0:
                # Cached/prefix keys have no padding info; treat them as visible.
                prefix = torch.ones(
                    (batch_size, 1, 1, pad_len),
                    dtype=torch.bool,
                    device=attention_mask.device,
                )
                promoted = torch.cat([prefix, promoted], dim=-1)
            attention_mask = promoted
        else:
            raise ValueError(
                f"attention_mask shape {attention_mask.shape} is not compatible with key_len {key_len} or query_len {query_len}."
            )
    # Dispatch to the requested mask builder; relu_mask simply ignores the
    # extra window_size keyword via its **kwargs.
    builders = {"topk": topk_mask, "relu": relu_mask}
    builder = builders.get(type)
    if builder is None:
        raise ValueError(f"Unsupported mask type: {type}. Supported types are 'topk' and 'relu'.")
    return builder(
        attention_bias=attention_bias,
        attention_mask=attention_mask,
        window_size=window_size,
        min_dtype=min_dtype,
        block_size=block_size,
    )