 from .import_utils import is_flash_dmattn_available
 
 from transformers.utils import logging
-from transformers.integrations import flash_attention
 
 
 logger = logging.get_logger(__name__)
 
 
-def fdma_peft_integration_check(q, k, v, bias, target_dtype: Optional[torch.dtype] = None):
+def fdma_peft_integration_check(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    target_dtype: Optional[torch.dtype] = None
+):
+    """
+    PEFT usually casts the layer norms to float32 for training stability,
+    so the input hidden states get silently cast to float32. We therefore need to
+    cast them back to float16 / bfloat16 to make sure everything works as expected.
+    This might slow down training & inference, so it is recommended not to cast the LayerNorms!
+    """
     if target_dtype and q.dtype == torch.float32:
         logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-dmattn compatibility.")
-        q, k, v, bias = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype), bias.to(target_dtype)
+        q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
+        if bias is not None:
+            bias = bias.to(target_dtype)
     return q, k, v, bias
 
 
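For reference, a minimal usage sketch of the check above (the tensor shapes and the bfloat16 target are made up for illustration and are not part of this patch):

    import torch

    # fp32 inputs (e.g. after a PEFT-upcast LayerNorm) are cast back to the compute dtype;
    # a None bias passes through untouched.
    q = torch.randn(1, 128, 8, 64, dtype=torch.float32)
    k = torch.randn(1, 128, 8, 64, dtype=torch.float32)
    v = torch.randn(1, 128, 8, 64, dtype=torch.float32)
    q, k, v, bias = fdma_peft_integration_check(q, k, v, bias=None, target_dtype=torch.bfloat16)
    assert q.dtype == torch.bfloat16 and bias is None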
@@ -43,8 +56,24 @@ def _lazy_imports(impl: Optional[str]):
 
 
 class FlashDynamicMaskAttentionKwargs(TypedDict, total=False):
-    cumulative_seqlens_q: Optional[torch.LongTensor]
-    cumulative_seqlens_k: Optional[torch.LongTensor]
+    """
+    Keyword arguments for Flash Dynamic Mask Attention with Compile.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`, *optional*):
+            Gets cumulative sequence length for query state.
+        cu_seq_lens_k (`torch.LongTensor`, *optional*):
+            Gets cumulative sequence length for key state.
+        max_length_q (`int`, *optional*):
+            Maximum sequence length for query state.
+        max_length_k (`int`, *optional*):
+            Maximum sequence length for key state.
+    """
+
+    cu_seq_lens_q: Optional[torch.LongTensor]
+    cu_seq_lens_k: Optional[torch.LongTensor]
+    max_length_q: Optional[int]
+    max_length_k: Optional[int]
 
 
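For context, a sketch of how these kwargs are typically filled for two packed sequences of lengths 3 and 5 (the numbers are illustrative, not taken from this patch):

    import torch

    # Cumulative lengths mark sequence boundaries in the packed batch;
    # max_length_* is the longest individual sequence on each side.
    kwargs: FlashDynamicMaskAttentionKwargs = {
        "cu_seq_lens_q": torch.tensor([0, 3, 8], dtype=torch.int64),
        "cu_seq_lens_k": torch.tensor([0, 3, 8], dtype=torch.int64),
        "max_length_q": 5,
        "max_length_k": 5,
    }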
 def _flash_dynamic_mask_attention_forward(
@@ -58,15 +87,14 @@ def _flash_dynamic_mask_attention_forward(
     is_causal: bool,
     softmax_scale: Optional[float] = None,
     softcap: Optional[float] = None,
-    keep_window_size: Optional[int] = None,
+    window_size: Optional[int] = None,
     deterministic: Optional[bool] = None,
     target_dtype: Optional[torch.dtype] = None,
     implementation: Optional[str] = None,
     **kwargs,
 ):
     dtype = query_states.dtype
     min_dtype = torch.finfo(dtype).min
-    batch_size, _, num_kv_heads, _ = key_states.shape
 
     if not all(k in globals() for k in ("_flash_fn")):
         flash_fn = _lazy_imports(implementation)
@@ -93,14 +121,12 @@ def _flash_dynamic_mask_attention_forward(
         min_dtype
     )
 
-    if keep_window_size is not None and key_length > keep_window_size:
+    if window_size is not None and key_length > window_size:
         topk_values, topk_indices = torch.topk(
-            attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
+            attention_bias, window_size, dim=-1, largest=True, sorted=False
         )
         attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
         attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
-    else:
-        attention_mask = None
 
     out = flash_fn(
         query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
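The top-k selection in the last hunk can be exercised on its own; a rough standalone sketch with made-up shapes (batch 1, 2 heads, 4 queries, 8 keys, window_size of 2):

    import torch

    # Keep only the window_size largest bias entries per query; positions whose bias equals
    # the dtype minimum (already masked out) stay False in the resulting boolean mask.
    attention_bias = torch.randn(1, 2, 4, 8)
    window_size = 2
    min_dtype = torch.finfo(attention_bias.dtype).min
    topk_values, topk_indices = torch.topk(attention_bias, window_size, dim=-1, largest=True, sorted=False)
    attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool)
    attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
    assert int(attention_mask.sum(dim=-1).max()) <= window_size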