 import torch
 from torch import nn
 
-from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention, TransformerBlock
+from transformers.models.distilbert.modeling_distilbert import (
+    DistilBertFlashAttention2,
+    DistilBertSdpaAttention,
+    MultiHeadSelfAttention,
+    TransformerBlock,
+)
+from transformers.utils import is_flash_attn_2_available, logging
 
 from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel
 from ...utils import prefix_attention_mask
 from .mixin_distilbert import DistilBertMultiHeadSelfAttentionMixin, DistilBertTransfomerBlockAdaptersMixin
 
 
+if is_flash_attn_2_available():
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+
 class MultiHeadSelfAttentionWithAdapters(DistilBertMultiHeadSelfAttentionMixin, MultiHeadSelfAttention):
     def forward(
         self,
@@ -66,18 +79,20 @@ def shape(x: torch.Tensor) -> torch.Tensor:
 
         def unshape(x: torch.Tensor) -> torch.Tensor:
             """group heads"""
-            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
+            return x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.n_heads * dim_per_head)
 
         q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
         k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
         v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
 
+        # >>> START AH Changes <<<
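+        # Adapter support: match q/k/v (and the attention mask) across parallel composition branches,
+        # then let prefix tuning prepend prefix states to k/v, which extends the key length and the mask.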
         q, k, v = match_attn_matrices_for_parallel(q, k, v)
         (mask,) = adjust_tensors_for_parallel(q, mask)
 
         k, v, mask = self.prefix_tuning(k, v, value, mask, invert_mask=False)
         bs = k.size(0)  # reset for Parallel block
         (q,) = adjust_tensors_for_parallel(k, q)
+        # >>> END AH Changes <<<
 
         mask_reshp = (bs, 1, 1, k.size(2))
 
@@ -105,6 +120,172 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
             return (context,)
 
 
+class DistilBertSdpaAttentionWithAdapters(DistilBertMultiHeadSelfAttentionMixin, DistilBertSdpaAttention):
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, ...]:
+        """
+        Parameters:
+            query: torch.tensor(bs, seq_length, dim)
+            key: torch.tensor(bs, seq_length, dim)
+            value: torch.tensor(bs, seq_length, dim)
+            mask: torch.tensor(bs, seq_length)
+
+        Returns:
+            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights
+            context: torch.tensor(bs, seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
+        """
+        if output_attentions or head_mask is not None:
+            logger.warning_once(
+                "DistilBertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
+                " `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
+                " the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
+                ' removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                query,
+                key,
+                value,
+                mask,
+                head_mask,
+                output_attentions,
+            )
+
+        batch_size, _, _ = query.size()
+        dim_per_head = self.dim // self.n_heads
+
+        def shape(x: torch.Tensor) -> torch.Tensor:
+            """separate heads"""
+            # keep first dim due to parallel composition
+            return x.view(x.shape[0], -1, self.n_heads, dim_per_head).transpose(1, 2)
+
+        def unshape(x: torch.Tensor) -> torch.Tensor:
+            """group heads"""
+            return x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.n_heads * dim_per_head)
+
+        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
+        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
+        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
+
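+        # Same adapter hooks as in the eager implementation: batch-align q/k/v for parallel
+        # composition and insert prefix-tuning key/value states before calling SDPA.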
+        # >>> START AH Changes <<<
+        q, k, v = match_attn_matrices_for_parallel(q, k, v)
+        (mask,) = adjust_tensors_for_parallel(q, mask)
+
+        k, v, mask = self.prefix_tuning(k, v, value, mask, invert_mask=False)
+        (q,) = adjust_tensors_for_parallel(k, q)
+        # >>> END AH Changes <<<
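+        # Note: unlike the eager path, no `bs` reset is needed here because `unshape`
+        # reads the batch size from the tensor itself.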
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and q.device.type == "cuda" and mask is not None:
+            q = q.contiguous()
+            k = k.contiguous()
+            v = v.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=mask,
+            dropout_p=self.dropout_prob if self.training else 0.0,
+            is_causal=False,
+        )
+
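+        # Merge heads: (bs, n_heads, q_length, dim_per_head) -> (bs, q_length, dim) before the output projection.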
+        attn_output = unshape(attn_output)
+        attn_output = self.out_lin(attn_output)
+
+        return (attn_output,)
+
+
+class DistilBertFlashAttention2WithAdapters(DistilBertMultiHeadSelfAttentionMixin, DistilBertFlashAttention2):
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, ...]:
+        """
+        Parameters:
+            query: torch.tensor(bs, seq_length, dim)
+            key: torch.tensor(bs, seq_length, dim)
+            value: torch.tensor(bs, seq_length, dim)
+            mask: torch.tensor(bs, seq_length)
+
+        Returns:
+            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights
+            context: torch.tensor(bs, seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
+        """
+        batch_size, q_length, dim = query.size()
+
+        dim_per_head = self.dim // self.n_heads
+
+        def reshape(x: torch.Tensor) -> torch.Tensor:
+            """separate heads"""
+            return x.view(x.shape[0], -1, self.n_heads, dim_per_head)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x n_heads x head_dim
+        query_states = reshape(self.q_lin(query))
+        key_states = reshape(self.k_lin(key))
+        value_states = reshape(self.v_lin(value))
+
+        attn_dropout = self.config.attention_dropout if self.training else 0.0
+
+        # In PEFT, the layer norms are usually cast to float32 for training stability, so the input
+        # hidden states may be silently upcast to float32. We therefore cast them back to the
+        # expected dtype just to be sure everything works as intended. This might slow down training
+        # and inference, so it is recommended not to cast the LayerNorms to fp32.
+        # (LlamaRMSNorm handles it correctly.)
+
+        if query_states.dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_lin.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
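+        # `_flash_attention_forward` expects q/k/v in (bs, seq_length, n_heads, head_dim) layout and
+        # uses the 2D padding mask to un-pad the sequences before the kernel call.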
+        attn_weights = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            mask,
+            q_length,
+            dropout=attn_dropout,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+
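+        # Flash attention returns (bs, q_length, n_heads, head_dim); flatten the head dimensions
+        # before applying the output projection.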
+        attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head)
+        attn_output = self.out_lin(attn_weights_reshaped)
+
+        if output_attentions:
+            return (attn_output, attn_weights)
+        else:
+            return (attn_output,)
+
+
 class TransformerBlockWithAdapters(DistilBertTransfomerBlockAdaptersMixin, TransformerBlock):
     def forward(
         self,
@@ -123,7 +304,7 @@ def forward(
             torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
         """
         adjust_tensors_for_parallel_(x, attn_mask)
-        attn_mask = prefix_attention_mask(attn_mask, dim=1, prefix_value=1)  # type: ignore
+        attn_mask = prefix_attention_mask(attn_mask, dim=[2, 3], prefix_value=1)  # type: ignore
 
         # Self-Attention
         sa_output = self.attention(