
Commit e3ff84c

Refactor fdma_peft_integration_check and _flash_dynamic_mask_attention_forward for clarity and consistency; rename keep_window_size to window_size and enhance FlashDynamicMaskAttentionKwargs documentation.
1 parent 32074fa commit e3ff84c

1 file changed: +37 −11 lines changed

flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py

Lines changed: 37 additions & 11 deletions
@@ -17,16 +17,29 @@
 from .import_utils import is_flash_dmattn_available
 
 from transformers.utils import logging
-from transformers.integrations import flash_attention
 
 
 logger = logging.get_logger(__name__)
 
 
-def fdma_peft_integration_check(q, k, v, bias, target_dtype: Optional[torch.dtype] = None):
+def fdma_peft_integration_check(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    target_dtype: Optional[torch.dtype] = None
+):
+    """
+    PEFT usually casts the layer norms in float32 for training stability reasons
+    therefore the input hidden states gets silently casted in float32. Hence, we need
+    cast them back in float16 / bfloat16 just to be sure everything works as expected.
+    This might slowdown training & inference so it is recommended to not cast the LayerNorms!
+    """
     if target_dtype and q.dtype == torch.float32:
         logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-dmattn compatibility.")
-        q, k, v, bias = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype), bias.to(target_dtype)
+        q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
+        if bias is not None:
+            bias = bias.to(target_dtype)
     return q, k, v, bias
 
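For illustration, a minimal sketch (not part of the commit) of how the refactored check behaves, assuming the import path matches the file shown above; the tensor shapes and the bias=None call are made up. The key behavioral change is that a missing bias no longer receives a `.to()` call:

    import torch

    from flash_dmattn.integrations.modeling_flash_dynamic_mask_attention_utils import (
        fdma_peft_integration_check,
    )

    # Hypothetical fp32 inputs of shape (batch, seq_len, num_heads, head_dim),
    # e.g. as produced when PEFT upcasts the LayerNorms.
    q = torch.randn(1, 128, 8, 64, dtype=torch.float32)
    k = torch.randn(1, 128, 8, 64, dtype=torch.float32)
    v = torch.randn(1, 128, 8, 64, dtype=torch.float32)

    # Before this commit, bias=None would fail on bias.to(target_dtype);
    # the cast is now skipped when no bias is supplied.
    q, k, v, bias = fdma_peft_integration_check(q, k, v, None, target_dtype=torch.bfloat16)
    assert q.dtype == torch.bfloat16 and bias is None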

@@ -43,8 +56,24 @@ def _lazy_imports(impl: Optional[str]):
 
 
 class FlashDynamicMaskAttentionKwargs(TypedDict, total=False):
-    cumulative_seqlens_q: Optional[torch.LongTensor]
-    cumulative_seqlens_k: Optional[torch.LongTensor]
+    """
+    Keyword arguments for Flash Dynamic Mask Attention with Compile.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`, *optional*)
+            Gets cumulative sequence length for query state.
+        cu_seq_lens_k (`torch.LongTensor`, *optional*)
+            Gets cumulative sequence length for key state.
+        max_length_q (`int`, *optional*):
+            Maximum sequence length for query state.
+        max_length_k (`int`, *optional*):
+            Maximum sequence length for key state.
+    """
+
+    cu_seq_lens_q: Optional[torch.LongTensor]
+    cu_seq_lens_k: Optional[torch.LongTensor]
+    max_length_q: Optional[int]
+    max_length_k: Optional[int]
 
 
 def _flash_dynamic_mask_attention_forward(
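As a usage sketch only (assumptions: callers follow the usual varlen convention of prefix-summed sequence lengths starting at 0, and the packed lengths 3 and 5 are made up), the new keys could be populated like this:

    import torch
    import torch.nn.functional as F

    from flash_dmattn.integrations.modeling_flash_dynamic_mask_attention_utils import (
        FlashDynamicMaskAttentionKwargs,
    )

    # Hypothetical packed batch: two sequences of lengths 3 and 5.
    seq_lens = torch.tensor([3, 5], dtype=torch.int64)
    cu_seq_lens = F.pad(torch.cumsum(seq_lens, dim=0), (1, 0))  # tensor([0, 3, 8])

    fdma_kwargs: FlashDynamicMaskAttentionKwargs = {
        "cu_seq_lens_q": cu_seq_lens,
        "cu_seq_lens_k": cu_seq_lens,
        "max_length_q": int(seq_lens.max()),
        "max_length_k": int(seq_lens.max()),
    }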
@@ -58,15 +87,14 @@ def _flash_dynamic_mask_attention_forward(
     is_causal: bool,
     softmax_scale: Optional[float] = None,
     softcap: Optional[float] = None,
-    keep_window_size: Optional[int] = None,
+    window_size: Optional[int] = None,
     deterministic: Optional[bool] = None,
     target_dtype: Optional[torch.dtype] = None,
     implementation: Optional[str] = None,
     **kwargs,
 ):
     dtype = query_states.dtype
     min_dtype = torch.finfo(dtype).min
-    batch_size, _, num_kv_heads, _ = key_states.shape
 
     if not all(k in globals() for k in ("_flash_fn")):
         flash_fn = _lazy_imports(implementation)
@@ -93,14 +121,12 @@ def _flash_dynamic_mask_attention_forward(
         min_dtype
     )
 
-    if keep_window_size is not None and key_length > keep_window_size:
+    if window_size is not None and key_length > window_size:
         topk_values, topk_indices = torch.topk(
-            attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
+            attention_bias, window_size, dim=-1, largest=True, sorted=False
         )
         attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
         attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
-    else:
-        attention_mask = None
 
     out = flash_fn(
         query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
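To make the renamed window_size path concrete, here is a standalone sketch with made-up shapes (not from the commit): each query position keeps only the window_size keys with the largest bias, and entries already pushed down to the dtype minimum are never re-enabled.

    import torch

    batch, num_heads, query_len, key_len, window_size = 1, 2, 4, 16, 8
    dtype = torch.float32
    min_dtype = torch.finfo(dtype).min

    # Made-up bias; in the real forward pass, masked positions are already set to min_dtype.
    attention_bias = torch.randn(batch, num_heads, query_len, key_len, dtype=dtype)
    attention_bias[..., -2:] = min_dtype  # pretend the last two keys are masked

    if window_size is not None and key_len > window_size:
        # Select the window_size largest bias entries per query position.
        topk_values, topk_indices = torch.topk(
            attention_bias, window_size, dim=-1, largest=True, sorted=False
        )
        # Enable only the selected positions whose bias is not the masked-out minimum.
        attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool)
        attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)

    # Each query row now attends to at most window_size keys,
    # and the explicitly masked keys stay masked.
    assert int(attention_mask.sum(dim=-1).max()) <= window_size
    assert not attention_mask[..., -2:].any()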
