
Commit 510ef4d

Refactors FDMA utils and centralizes mask creation
Renames internal attention call sites to FDMA-prefixed names for clarity and consistency. Adds lazy import and wiring for a mask-creation utility and uses it to build sliding-window masks instead of ad hoc top-k logic, improving readability and numerical correctness by taking the mask fill minimum from the attention-bias dtype rather than the query dtype. Removes the local pad/unpad fallbacks in favor of the package implementations. Updates the lazy loader return signature and the kwargs-processing hook accordingly.
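
As context for the diff below, a minimal sketch of the sliding-window selection being centralized, in plain PyTorch. The _topk_window_mask helper is hypothetical and only mirrors the inline top-k logic this commit removes; the actual replacement is flash_dmattn.utils.mask.create_mask, whose signature is known here only from the keyword arguments visible in the diff.

import torch

def _topk_window_mask(attention_bias, attention_mask, window_size):
    # Hypothetical helper mirroring the removed inline logic, not the package API.
    # Take the fill minimum from the bias dtype (the fix in this commit), not the query dtype.
    min_dtype = torch.finfo(attention_bias.dtype).min
    bias = attention_bias if attention_mask is None else attention_bias.masked_fill(~attention_mask, min_dtype)
    # Keep, per query position, the window_size keys with the largest bias.
    topk_values, topk_indices = torch.topk(bias.detach(), window_size, dim=-1, largest=True, sorted=False)
    # Scatter True at the kept positions; entries still equal to min_dtype remain masked out.
    return torch.zeros_like(attention_bias, dtype=torch.bool).scatter_(-1, topk_indices, topk_values != min_dtype)

# Toy shapes: (batch, num_heads, query_len, key_len)
bias = torch.randn(1, 2, 4, 16)
mask = _topk_window_mask(bias, attention_mask=None, window_size=8)
print(mask.shape, mask.sum(dim=-1))  # at most 8 keys kept per query position

The create_mask call in the diff produces this kind of boolean window mask in one place instead of repeating the top-k/scatter pattern at the call site.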
1 parent: a06bff1 · commit: 510ef4d

File tree

1 file changed: +26 −84 lines


flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py

Lines changed: 26 additions & 84 deletions
@@ -27,10 +27,11 @@
 
 
 # `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves
-_flash_fn = None
-_flash_varlen_fn = None
+_fdma_fn = None
+_fdma_varlen_fn = None
 _pad_fn = None
 _unpad_fn = None
+_create_mask_fn = None
 
 # function that processes kwargs, generalized to handle any supported kwarg within the function
 _process_flash_kwargs_fn = None
@@ -53,13 +54,12 @@ def _lazy_imports(implementation: Optional[str]):
     """
     is_fdma = is_flash_dmattn_available()
 
-    pad_input, unpad_input = _pad_input, _unpad_input
-
     if (implementation == "flash_dmattn" and is_fdma) or (implementation is None and is_fdma):
         from flash_dmattn import flash_dmattn_func, flash_dmattn_varlen_func
         from flash_dmattn.utils.padding import pad_input, unpad_input
+        from flash_dmattn.utils.mask import create_mask
 
-    return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input
+    return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input, create_mask
 
 
 def _lazy_define_process_function(flash_function):
@@ -90,15 +90,15 @@ def lazy_import_flash_dynamic_mask_attention(implementation: Optional[str], forc
     NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can
     work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`.
     """
-    global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
-    if force_import or any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
-        _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
-
+    global _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn
+    if force_import or any(k is None for k in [_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn]):
+        _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn = _lazy_imports(implementation)
+
     global _process_flash_kwargs_fn
     if force_import or _process_flash_kwargs_fn is None:
-        _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
+        _process_flash_kwargs_fn = _lazy_define_process_function(_fdma_varlen_fn)
 
-    return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
+    return (_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn), _process_flash_kwargs_fn
 
 
 def _index_first_axis(tensor, indices):
@@ -113,57 +113,6 @@ def _index_first_axis(tensor, indices):
     return reshaped_tensor[indices]
 
 
-def _unpad_input(hidden_states, attention_mask, unused_mask=None):
-    """
-    unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
-
-    Arguments:
-        hidden_states: (batch, seqlen, ...)
-        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
-        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
-
-    Return:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
-        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
-        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
-        max_seqlen_in_batch: int
-        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
-    """
-    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
-    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
-    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-
-    return (
-        _index_first_axis(hidden_states, indices),
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-        used_seqlens_in_batch,
-    )
-
-
-def _pad_input(hidden_states, indices, batch, seqlen):
-    """
-    pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
-
-    Arguments:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
-        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
-        batch: int, batch size for the padded sequence.
-        seqlen: int, maximum sequence length for the padded sequence.
-
-    Return:
-        hidden_states: (batch, seqlen, ...)
-    """
-    dim = hidden_states.shape[1:]
-    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
-    output[indices] = hidden_states
-    return output.view(batch, seqlen, *dim)
-
-
 def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
     """
     Retrieves indexing data required to repad unpadded (ragged) tensors.
@@ -527,7 +476,7 @@ def _flash_dynamic_mask_attention_forward(
             "If shape of attention_mask is (batch_size, seq_len), attention_bias has to be None."
         )
 
-    (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation)
+    (fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation)
 
     # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
     query_states, key_states, value_states, attention_bias = fdma_peft_integration_check(
@@ -546,10 +495,10 @@
         **kwargs,
     )
 
-    # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
+    # We will use `fdma_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
     # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
     # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
-    # use `flash_varlen_fn` knowing we already have all necessary the kwargs.
+    # use `fdma_varlen_fn` knowing we already have all necessary the kwargs.
     #
     # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model.
     # See #39121 for more information.
@@ -569,7 +518,7 @@
         if "mps" in str(q.device):
             cu_seq_lens_k = cu_seq_lens_k.clone()
 
-        out_unpad = flash_varlen_fn(
+        out_unpad = fdma_varlen_fn(
            q,
            k,
            v,
@@ -600,7 +549,7 @@
         if "mps" in str(q.device):
             cu_seq_lens_k = cu_seq_lens_k.clone()
 
-        out = flash_varlen_fn(
+        out = fdma_varlen_fn(
            q,
            k,
            v,
@@ -624,24 +573,17 @@ def _flash_dynamic_mask_attention_forward(
            and window_size is not None
            and key_length > window_size
        ):
-            min_dtype = torch.finfo(query_states.dtype).min
-            if attention_mask is not None:
-                topk_values, topk_indices = torch.topk(
-                    attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
-                    window_size, dim=-1, largest=True, sorted=False
-                )
-                attention_mask = torch.zeros_like(
-                    attention_bias, dtype=torch.bool, device=attention_bias.device
-                ).scatter_(-1, topk_indices, topk_values != min_dtype)
-            else:
-                topk_values, topk_indices = torch.topk(
-                    attention_bias.detach(), window_size, dim=-1, largest=True, sorted=False
-                )
-                attention_mask = torch.zeros_like(
-                    attention_bias, dtype=torch.bool, device=attention_bias.device
-                ).scatter_(-1, topk_indices, topk_values != min_dtype)
+            attention_mask = create_mask_fn(
+                attention_bias,
+                attention_mask,
+                batch_size=query_states.size(0),
+                query_len=query_length,
+                key_len=key_length,
+                window_size=window_size,
+                min_dtype=torch.finfo(attention_bias.dtype).min,
+            )
 
-        out = flash_fn(
+        out = fdma_fn(
            query_states,
            key_states,
            value_states,
