Skip to content

Commit bd824a7

Browse files
authored
Merge pull request #218 from flash-algo/last-mask-version
Refactor create_mask function parameters
2 parents d209826 + 507cfe6 commit bd824a7

File tree

2 files changed

+8
-11
lines changed

2 files changed

+8
-11
lines changed

flash_sparse_attn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
 
 from typing import Optional
 
-__version__ = "1.2.3"
+__version__ = "1.2.4"
 
 
 # Import CUDA functions when available

flash_sparse_attn/utils/mask.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,12 @@ def relu_mask(
 
 def create_mask(
     attention_bias: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    batch_size: int,
     query_len: int,
-    key_len: int,
-    window_size: Optional[int],
-    min_dtype: Optional[float],
-    block_size: Optional[int],
     type: str = "topk",
+    attention_mask: Optional[torch.Tensor] = None,
+    window_size: Optional[int] = None,
+    min_dtype: Optional[float] = None,
+    block_size: Optional[int] = None,
 ) -> torch.Tensor:
     r"""
     This function creates a mask tensor for Flash Sparse Attention.
@@ -180,15 +178,13 @@ def create_mask(
     Args:
         attention_bias (torch.Tensor): The attention bias tensor of shape
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        query_len (int): The sequence length of the query.
+        type (str): The type of mask to create. Options are "topk" and "relu".
         attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
             (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
-        batch_size (int): The batch size.
-        query_len (int): The sequence length of the query.
-        key_len (int): The sequence length of the key.
         window_size (Optional[int]): The number of top elements to consider for the attention mask.
         min_dtype (Optional[float]): The minimum value to use for masking.
         block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
-        type (str): The type of mask to create. Options are "topk" and "relu".
 
     Returns:
         attention (Tensor): The attention mask tensor of shape
@@ -200,6 +196,7 @@ def create_mask(
 
     # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len)
     if attention_mask is not None and attention_mask.dim() == 2:
+        batch_size, key_len = attention_bias.shape[0], attention_bias.shape[-1]
         if attention_mask.shape[-1] == key_len:
            attention_mask = attention_mask.view(batch_size, 1, 1, key_len)
         elif attention_mask.shape[-1] == query_len:

0 commit comments

Comments (0)