
Commit 2ce1efe

Merge pull request #197 from SmallDoges/fix-195
[FEATURE SUPPORT] Centralize dynamic mask creation for FDMA
2 parents df5ade9 + 0dbd673 commit 2ce1efe

File tree: 3 files changed, +135 −85 lines

flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py

Lines changed: 26 additions & 84 deletions
```diff
@@ -27,10 +27,11 @@
 
 
 # `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves
-_flash_fn = None
-_flash_varlen_fn = None
+_fdma_fn = None
+_fdma_varlen_fn = None
 _pad_fn = None
 _unpad_fn = None
+_create_mask_fn = None
 
 # function that processes kwargs, generalized to handle any supported kwarg within the function
 _process_flash_kwargs_fn = None
@@ -53,13 +54,12 @@ def _lazy_imports(implementation: Optional[str]):
     """
     is_fdma = is_flash_dmattn_available()
 
-    pad_input, unpad_input = _pad_input, _unpad_input
-
     if (implementation == "flash_dmattn" and is_fdma) or (implementation is None and is_fdma):
         from flash_dmattn import flash_dmattn_func, flash_dmattn_varlen_func
         from flash_dmattn.utils.padding import pad_input, unpad_input
+        from flash_dmattn.utils.mask import create_mask
 
-        return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input
+        return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input, create_mask
 
 
 def _lazy_define_process_function(flash_function):
@@ -90,15 +90,15 @@ def lazy_import_flash_dynamic_mask_attention(implementation: Optional[str], forc
     NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can
     work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`.
     """
-    global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
-    if force_import or any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
-        _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
-
+    global _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn
+    if force_import or any(k is None for k in [_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn]):
+        _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn = _lazy_imports(implementation)
+
     global _process_flash_kwargs_fn
     if force_import or _process_flash_kwargs_fn is None:
-        _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
+        _process_flash_kwargs_fn = _lazy_define_process_function(_fdma_varlen_fn)
 
-    return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
+    return (_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn), _process_flash_kwargs_fn
 
 
 def _index_first_axis(tensor, indices):
@@ -113,57 +113,6 @@ def _index_first_axis(tensor, indices):
     return reshaped_tensor[indices]
 
 
-def _unpad_input(hidden_states, attention_mask, unused_mask=None):
-    """
-    unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
-
-    Arguments:
-        hidden_states: (batch, seqlen, ...)
-        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
-        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
-
-    Return:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
-        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
-        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
-        max_seqlen_in_batch: int
-        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
-    """
-    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
-    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
-    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-
-    return (
-        _index_first_axis(hidden_states, indices),
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-        used_seqlens_in_batch,
-    )
-
-
-def _pad_input(hidden_states, indices, batch, seqlen):
-    """
-    pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
-
-    Arguments:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
-        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
-        batch: int, batch size for the padded sequence.
-        seqlen: int, maximum sequence length for the padded sequence.
-
-    Return:
-        hidden_states: (batch, seqlen, ...)
-    """
-    dim = hidden_states.shape[1:]
-    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
-    output[indices] = hidden_states
-    return output.view(batch, seqlen, *dim)
-
-
 def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
     """
     Retrieves indexing data required to repad unpadded (ragged) tensors.
@@ -527,7 +476,7 @@ def _flash_dynamic_mask_attention_forward(
             "If shape of attention_mask is (batch_size, seq_len), attention_bias has to be None."
         )
 
-    (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation)
+    (fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation)
 
     # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
     query_states, key_states, value_states, attention_bias = fdma_peft_integration_check(
@@ -546,10 +495,10 @@
         **kwargs,
     )
 
-    # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
+    # We will use `fdma_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
     # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
     # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
-    #         use `flash_varlen_fn` knowing we already have all necessary the kwargs.
+    #         use `fdma_varlen_fn` knowing we already have all necessary the kwargs.
     #
     # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model.
    # See #39121 for more information.
@@ -569,7 +518,7 @@
         if "mps" in str(q.device):
             cu_seq_lens_k = cu_seq_lens_k.clone()
 
-        out_unpad = flash_varlen_fn(
+        out_unpad = fdma_varlen_fn(
             q,
             k,
             v,
@@ -600,7 +549,7 @@
         if "mps" in str(q.device):
             cu_seq_lens_k = cu_seq_lens_k.clone()
 
-        out = flash_varlen_fn(
+        out = fdma_varlen_fn(
             q,
             k,
             v,
@@ -624,24 +573,17 @@
             and window_size is not None
             and key_length > window_size
         ):
-            min_dtype = torch.finfo(query_states.dtype).min
-            if attention_mask is not None:
-                topk_values, topk_indices = torch.topk(
-                    attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
-                    window_size, dim=-1, largest=True, sorted=False
-                )
-                attention_mask = torch.zeros_like(
-                    attention_bias, dtype=torch.bool, device=attention_bias.device
-                ).scatter_(-1, topk_indices, topk_values != min_dtype)
-            else:
-                topk_values, topk_indices = torch.topk(
-                    attention_bias.detach(), window_size, dim=-1, largest=True, sorted=False
-                )
-                attention_mask = torch.zeros_like(
-                    attention_bias, dtype=torch.bool, device=attention_bias.device
-                ).scatter_(-1, topk_indices, topk_values != min_dtype)
+            attention_mask = create_mask_fn(
+                attention_bias,
+                attention_mask,
+                batch_size=query_states.size(0),
+                query_len=query_length,
+                key_len=key_length,
+                window_size=window_size,
+                min_dtype=torch.finfo(attention_bias.dtype).min,
+            )
 
-        out = flash_fn(
+        out = fdma_fn(
             query_states,
             key_states,
             value_states,
```
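Downstream callers now unpack a 5-tuple from the lazy import helper rather than the previous 4-tuple. A minimal sketch of the new calling convention, assuming `flash_dmattn` is installed (so the `implementation is None and is_fdma` branch above resolves the real kernels) and that the module is importable at its repository path:

```python
# Sketch of the new 5-tuple contract returned by the lazy import helper.
# Assumes flash_dmattn is installed so flash_dmattn_func / flash_dmattn_varlen_func
# and create_mask resolve; otherwise the fallback branch (not shown in this diff) applies.
from flash_dmattn.integrations.modeling_flash_dynamic_mask_attention_utils import (
    lazy_import_flash_dynamic_mask_attention,
)

(fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn), process_flash_kwargs_fn = (
    lazy_import_flash_dynamic_mask_attention(implementation=None)
)

# create_mask_fn is the new fifth entry; per the diff above it is only invoked on the
# keep-window path (attention_bias present and key_length > window_size).
assert create_mask_fn is not None
```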

flash_dmattn/utils/mask.py

Lines changed: 108 additions & 0 deletions
```diff
@@ -0,0 +1,108 @@
+# Copyright 2025 Jingze Shi and Liangdong Wang. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+
+
+def dynamic_mask(
+    attention_bias: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    window_size: int,
+    min_dtype: float,
+):
+    r"""
+    This function generates a dynamic mask based on the top-k attention bias.
+
+    Args:
+        attention_bias (torch.Tensor): The attention bias tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        window_size (int): The number of top elements to consider for the mask.
+        min_dtype (float): The minimum value to use for masking.
+
+    Returns:
+        attention_mask (Tensor): The attention mask tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+    """
+    attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
+    topk_values, topk_indices = torch.topk(
+        attention_bias.detach(),
+        window_size, dim=-1, largest=True, sorted=False
+    )
+    attention_mask = torch.zeros_like(
+        attention_bias, dtype=torch.bool, device=attention_bias.device
+    ).scatter_(-1, topk_indices, topk_values != min_dtype)
+    return attention_mask
+
+
+def create_mask(
+    attention_bias: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    batch_size: int,
+    query_len: int,
+    key_len: int,
+    window_size: int,
+    min_dtype: float,
+) -> torch.Tensor:
+    r"""
+    This function creates a mask tensor for Flash Dynamic Mask Attention.
+
+    If attention_mask is not of shape (batch_size, seq_len), it needs to match the shape of attention_bias.
+
+    Args:
+        attention_bias (torch.Tensor): The attention bias tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        batch_size (int): The batch size.
+        query_len (int): The sequence length of the query.
+        key_len (int): The sequence length of the key.
+        window_size (int): The number of top elements to consider for the attention mask.
+        min_dtype (float): The minimum value to use for masking.
+
+    Returns:
+        attention (Tensor): The attention mask tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+    """
+
+    # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len)
+    if attention_mask is not None and attention_mask.dim() == 2:
+        if attention_mask.shape[-1] == key_len:
+            attention_mask = attention_mask.view(batch_size, 1, 1, key_len)
+        elif attention_mask.shape[-1] == query_len:
+            pad_len = key_len - query_len
+            if pad_len > 0:
+                pad_mask = torch.ones(
+                    (batch_size, 1, 1, pad_len),
+                    dtype=torch.bool,
+                    device=attention_mask.device,
+                )
+                attention_mask = torch.cat(
+                    [pad_mask, attention_mask.view(batch_size, 1, 1, query_len)],
+                    dim=-1,
+                )
+            else:
+                attention_mask = attention_mask.view(batch_size, 1, 1, query_len)
+        else:
+            raise ValueError(
+                f"attention_mask shape {attention_mask.shape} is not compatible with key_len {key_len} or query_len {query_len}."
+            )
+
+    # Generate dynamic mask based on attention_bias and attention_mask
+    attention_mask = dynamic_mask(attention_bias, attention_mask, window_size, min_dtype)
+
+    return attention_mask
```
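For illustration, a small usage sketch of the new helper: it keeps the top-`window_size` keys per query according to `attention_bias` while respecting a standard 2D padding mask. Assumes `flash_dmattn` with this commit is installed; tensor sizes are arbitrary.

```python
import torch

from flash_dmattn.utils.mask import create_mask

batch_size, num_heads, query_len, key_len, window_size = 2, 4, 8, 8, 3

# Bias scores: larger values are more likely to survive the top-k selection.
attention_bias = torch.randn(batch_size, num_heads, query_len, key_len)

# Standard 2D padding mask (batch_size, key_len): True = valid key, False = padding.
attention_mask = torch.ones(batch_size, key_len, dtype=torch.bool)
attention_mask[0, -2:] = False  # pretend the first sequence has two padded keys

mask = create_mask(
    attention_bias,
    attention_mask,
    batch_size=batch_size,
    query_len=query_len,
    key_len=key_len,
    window_size=window_size,
    min_dtype=torch.finfo(attention_bias.dtype).min,
)

print(mask.shape)                                    # torch.Size([2, 4, 8, 8])
assert int(mask.sum(dim=-1).max()) <= window_size    # at most window_size keys kept per query
assert not mask[0, :, :, -2:].any()                  # padded keys are never selected
```

This reproduces, in one place, exactly the top-k masking that the integration file previously inlined in `_flash_dynamic_mask_attention_forward`.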

flash_dmattn/utils/padding.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -167,4 +167,4 @@ def upad_input(
         indices_q,
         (cu_seqlens_q, cu_seqlens_k),
         (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-    )
+    )
```

The closing parenthesis is unchanged in content; the recorded change is whitespace-only (most likely adding a newline at the end of the file).
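For context, the integration file above now relies on `pad_input`/`unpad_input` from this module instead of its deleted local copies. A rough round-trip sketch, assuming the packaged helpers keep the same contract as the removed `_pad_input`/`_unpad_input` shown in the first diff:

```python
import torch

from flash_dmattn.utils.padding import pad_input, unpad_input

batch, seqlen, dim = 2, 6, 4
hidden_states = torch.randn(batch, seqlen, dim)
attention_mask = torch.tensor(
    [[1, 1, 1, 1, 0, 0],
     [1, 1, 1, 1, 1, 1]],
    dtype=torch.int32,
)

# Drop padded positions -> (total_nnz, dim) plus the bookkeeping the varlen kernel needs.
# The trailing *rest absorbs any extra return values (e.g. seqused) the packaged helper may add.
unpadded, indices, cu_seqlens, max_seqlen, *rest = unpad_input(hidden_states, attention_mask)

# Scatter back into the padded (batch, seqlen, dim) layout after the attention call.
repadded = pad_input(unpadded, indices, batch, seqlen)

valid = attention_mask.bool()
assert torch.equal(repadded[valid], hidden_states[valid])
```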
