
Commit d7d87ef

Fixes top‑k attention masking and safe casting
Respects existing masks when applying keep‑window top‑k selection and aligns bias/mask shapes (3D↔4D) before selection. Builds boolean masks from indices and avoids bias overwrites. Prevents errors by casting inputs only when present, improving compatibility with PEFT/LoRA setups. Cleans up unused imports and variables for clarity.
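
For reference, the keep-window selection described above can be sketched in isolation as follows. This is a minimal standalone sketch that mirrors the logic in the diff below, not the library code itself; the tensor shapes, random inputs, and the keep_mask name are illustrative assumptions.

import torch

# Illustrative shapes (assumed): batch=1, heads=2, query_len=4, key_len=8.
attention_bias = torch.randn(1, 2, 4, 8)
attention_mask = torch.rand(1, 2, 4, 8) > 0.2   # existing boolean mask (True = attend)
keep_window_size = 3
min_dtype = torch.finfo(attention_bias.dtype).min

# Fill masked-out positions with the dtype minimum so top-k can only pick allowed keys,
# but do it on a temporary copy: the bias tensor itself is never overwritten.
topk_indices = torch.topk(
    attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
    keep_window_size, dim=-1, largest=True, sorted=False,
).indices

# Build a boolean keep mask from the selected indices and intersect it with the
# existing mask, so previously disallowed positions stay disallowed.
keep_mask = torch.zeros_like(attention_bias, dtype=torch.bool).scatter_(-1, topk_indices, True)
attention_mask = keep_mask & attention_mask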
1 parent b6d2ea7 · commit d7d87ef

File tree

1 file changed: +33 −20 lines

flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py

Lines changed: 33 additions & 20 deletions
@@ -1,4 +1,4 @@
-# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 Jingze Shi and Liangdong Wang and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
 from .import_utils import is_flash_dmattn_available
 
 from transformers.utils import logging
-from transformers.integrations import flash_attention
 
 
 logger = logging.get_logger(__name__)
@@ -26,7 +25,10 @@
 def fdma_peft_integration_check(q, k, v, bias, target_dtype: Optional[torch.dtype] = None):
     if target_dtype and q.dtype == torch.float32:
         logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-dmattn compatibility.")
-        q, k, v, bias = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype), bias.to(target_dtype)
+        q = q.to(target_dtype) if q is not None else None
+        k = k.to(target_dtype) if k is not None else None
+        v = v.to(target_dtype) if v is not None else None
+        bias = bias.to(target_dtype) if bias is not None else None
     return q, k, v, bias
 
 
@@ -66,7 +68,6 @@ def _flash_dynamic_mask_attention_forward(
 ):
     dtype = query_states.dtype
     min_dtype = torch.finfo(dtype).min
-    batch_size, _, num_kv_heads, _ = key_states.shape
 
     if not all(k in globals() for k in ("_flash_fn")):
         flash_fn = _lazy_imports(implementation)
@@ -85,22 +86,34 @@ def _flash_dynamic_mask_attention_forward(
         query_states, key_states, value_states, attention_bias, target_dtype
     )
 
-    if attention_mask is not None and attention_mask.dim() == 4:
-        if attention_bias.dim() == 3:
-            attention_bias = attention_bias.unsqueeze(-2)
-        attention_bias = attention_bias.masked_fill(
-            ~attention_mask,
-            min_dtype
-        )
-
-    if keep_window_size is not None and key_length > keep_window_size:
-        topk_values, topk_indices = torch.topk(
-            attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
-        )
-        attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
-        attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
-    else:
-        attention_mask = None
+    if (
+        attention_bias is not None
+        and keep_window_size is not None
+        and key_length > keep_window_size
+    ):
+        if attention_mask is not None:
+            if attention_mask.dim() == 4 and attention_bias.dim() == 3:
+                attention_bias_for_topk = attention_bias.unsqueeze(-2).expand_as(attention_mask)
+            else:
+                attention_bias_for_topk = attention_bias
+
+            topk_indices = torch.topk(
+                attention_bias_for_topk.masked_fill(~attention_mask, min_dtype).detach(),
+                keep_window_size,
+                dim=-1, largest=True, sorted=False,
+            ).indices
+            attention_mask = torch.zeros_like(attention_bias_for_topk, dtype=torch.bool).scatter_(
+                -1, topk_indices, True
+            ) & attention_mask
+        else:
+            topk_indices = torch.topk(
+                attention_bias.detach(),
+                keep_window_size,
+                dim=-1, largest=True, sorted=False,
+            ).indices
+            attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool).scatter_(
+                -1, topk_indices, True
+            )
 
     out = flash_fn(
         query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
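
The None-safe casting change can be exercised on its own with a sketch like the one below. The cast_if_present helper and the call site are hypothetical, written only to illustrate the guard pattern the diff adds to fdma_peft_integration_check; they are not part of the library.

import torch
from typing import Optional

def cast_if_present(t: Optional[torch.Tensor], target_dtype: torch.dtype) -> Optional[torch.Tensor]:
    # Mirror of the guard added in the diff: cast only inputs that are actually present.
    return t.to(target_dtype) if t is not None else None

q = torch.randn(1, 4, 2, 8, dtype=torch.float32)
k = torch.randn(1, 8, 2, 8, dtype=torch.float32)
v = torch.randn(1, 8, 2, 8, dtype=torch.float32)
bias = None  # e.g. a PEFT/LoRA setup that passes no attention bias

# The old tuple cast called bias.to(target_dtype) and raised AttributeError when bias was None;
# casting each input only when present keeps the call working.
q, k, v, bias = (cast_if_present(t, torch.bfloat16) for t in (q, k, v, bias))
assert bias is None and q.dtype == torch.bfloat16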
