@@ -1,4 +1,4 @@
-# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 Jingze Shi and Liangdong Wang and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
 from .import_utils import is_flash_dmattn_available
 
 from transformers.utils import logging
-from transformers.integrations import flash_attention
 
 
 logger = logging.get_logger(__name__)
@@ -26,7 +25,10 @@
 def fdma_peft_integration_check(q, k, v, bias, target_dtype: Optional[torch.dtype] = None):
     if target_dtype and q.dtype == torch.float32:
         logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-dmattn compatibility.")
-        q, k, v, bias = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype), bias.to(target_dtype)
+        q = q.to(target_dtype) if q is not None else None
+        k = k.to(target_dtype) if k is not None else None
+        v = v.to(target_dtype) if v is not None else None
+        bias = bias.to(target_dtype) if bias is not None else None
     return q, k, v, bias
 
 
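
For context, a minimal standalone sketch (not part of the diff; shapes are hypothetical) of what the None-safe casts above change: previously a `None` bias crashed on `bias.to(target_dtype)`, while the per-tensor guards now pass `None` through unchanged.

```python
# Hypothetical usage sketch of fdma_peft_integration_check after this change;
# assumes the function is imported from the module modified above.
import torch

q = torch.randn(1, 8, 128, 64, dtype=torch.float32)  # hypothetical (batch, heads, seq, head_dim)
k = torch.randn(1, 8, 128, 64, dtype=torch.float32)
v = torch.randn(1, 8, 128, 64, dtype=torch.float32)

# bias=None no longer raises AttributeError; the tensors are still downcast.
q, k, v, bias = fdma_peft_integration_check(q, k, v, None, target_dtype=torch.bfloat16)
assert q.dtype == torch.bfloat16 and bias is None
```
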
@@ -66,7 +68,6 @@ def _flash_dynamic_mask_attention_forward(
 ):
     dtype = query_states.dtype
     min_dtype = torch.finfo(dtype).min
-    batch_size, _, num_kv_heads, _ = key_states.shape
 
     if not all(k in globals() for k in ("_flash_fn")):
         flash_fn = _lazy_imports(implementation)
@@ -85,22 +86,34 @@ def _flash_dynamic_mask_attention_forward(
         query_states, key_states, value_states, attention_bias, target_dtype
     )
 
-    if attention_mask is not None and attention_mask.dim() == 4:
-        if attention_bias.dim() == 3:
-            attention_bias = attention_bias.unsqueeze(-2)
-        attention_bias = attention_bias.masked_fill(
-            ~attention_mask,
-            min_dtype
-        )
-
-    if keep_window_size is not None and key_length > keep_window_size:
-        topk_values, topk_indices = torch.topk(
-            attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
-        )
-        attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
-        attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
-    else:
-        attention_mask = None
+    if (
+        attention_bias is not None
+        and keep_window_size is not None
+        and key_length > keep_window_size
+    ):
+        if attention_mask is not None:
+            if attention_mask.dim() == 4 and attention_bias.dim() == 3:
+                attention_bias_for_topk = attention_bias.unsqueeze(-2).expand_as(attention_mask)
+            else:
+                attention_bias_for_topk = attention_bias
+
+            topk_indices = torch.topk(
+                attention_bias_for_topk.masked_fill(~attention_mask, min_dtype).detach(),
+                keep_window_size,
+                dim=-1, largest=True, sorted=False,
+            ).indices
+            attention_mask = torch.zeros_like(attention_bias_for_topk, dtype=torch.bool).scatter_(
+                -1, topk_indices, True
+            ) & attention_mask
+        else:
+            topk_indices = torch.topk(
+                attention_bias.detach(),
+                keep_window_size,
+                dim=-1, largest=True, sorted=False,
+            ).indices
+            attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool).scatter_(
+                -1, topk_indices, True
+            )
 
     out = flash_fn(
         query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
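
As a small self-contained sketch (outside the diff; sizes and values are made up), the rebuilt top-k path above can be checked in isolation: exactly `keep_window_size` keys survive per query row, and the final `& attention_mask` guarantees that already-masked (e.g. padded) keys stay masked even though they are filled with `min_dtype` before `torch.topk`.

```python
# Hypothetical sketch of the new keep-window mask construction, mirroring the added branch.
import torch

keep_window_size = 4
attention_bias = torch.randn(1, 2, 1, 16)          # hypothetical (batch, heads, q_len, key_len)
attention_mask = torch.ones_like(attention_bias, dtype=torch.bool)
attention_mask[..., 10:] = False                   # pretend the last 6 keys are padding
min_dtype = torch.finfo(attention_bias.dtype).min

topk_indices = torch.topk(
    attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
    keep_window_size, dim=-1, largest=True, sorted=False,
).indices
keep_mask = torch.zeros_like(attention_bias, dtype=torch.bool).scatter_(-1, topk_indices, True) & attention_mask

assert keep_mask.sum(dim=-1).eq(keep_window_size).all()  # keep_window_size keys kept per query
assert not (keep_mask & ~attention_mask).any()           # padded keys remain masked
```

The trailing `& attention_mask` is what makes the selection safe when fewer than `keep_window_size` keys are valid: any `min_dtype`-filled position that slips into the top-k is masked back out.
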