 from transformers.activations import gelu_new
 from transformers.models.bert import modeling_bert
 from transformers.models.bert.modeling_bert import BertEncoder, BertOnlyMLMHead, BertPooler
+from transformers.pytorch_utils import Conv1D
 from transformers.utils import is_flash_attn_2_available, logging

 if is_flash_attn_2_available():
@@ -50,7 +51,7 @@ def flash_attention_forward(
             The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         is_causal (`bool`, *optional*):
     """
-    batch_size, query_length = query_states.shape[:2]
+    batch_size, query_length, n_heads, head_dim = query_states.shape
     query_states = query_states.to(torch.bfloat16)
    key_states = key_states.to(torch.bfloat16)
    value_states = value_states.to(torch.bfloat16)
@@ -91,7 +92,7 @@ def flash_attention_forward(
         softmax_scale=softmax_scale,
         causal=is_causal,
     )
-    return attn_output.reshape(batch_size, query_length, -1)
+    return attn_output.reshape(batch_size, query_length, n_heads * head_dim)


 # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
@@ -161,11 +162,10 @@ def __init__(self, config, position_embedding_type=None):
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-
-        self.query = nn.Linear(config.hidden_size, self.all_head_size)
-        self.key = nn.Linear(config.hidden_size, self.all_head_size)
-        self.value = nn.Linear(config.hidden_size, self.all_head_size)
-
+        self.split_size = config.hidden_size
+        self.embed_dim = config.hidden_size
+        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

     def split_heads(self, x: torch.Tensor) -> torch.Tensor:
@@ -182,8 +182,8 @@ def forward(
         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor]:
-        mixed_query_layer = self.query(hidden_states)

+        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
         # If this is instantiated as a cross-attention module, the keys
         # and values come from an encoder; the attention mask needs to be
         # such that the encoder's padding tokens are not attended to.
@@ -195,29 +195,32 @@ def forward(
             value_layer = past_key_value[1]
             attention_mask = encoder_attention_mask
         elif is_cross_attention:
-            key_layer = self.split_heads(self.key(encoder_hidden_states))
-            value_layer = self.split_heads(self.value(encoder_hidden_states))
+            key_layer = self.split_heads(key)
+            value_layer = self.split_heads(value)
             attention_mask = encoder_attention_mask
         elif past_key_value is not None:
-            key_layer = self.split_heads(self.key(hidden_states))
-            value_layer = self.split_heads(self.value(hidden_states))
+            key_layer = self.split_heads(key)
+            value_layer = self.split_heads(value)
             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            key_layer = self.split_heads(self.key(hidden_states))
-            value_layer = self.split_heads(self.value(hidden_states))
+            key_layer = self.split_heads(key)
+            value_layer = self.split_heads(value)

-        query_layer = self.split_heads(mixed_query_layer)
+        query_layer = self.split_heads(query)
+        attn_dropout = self.dropout.p if self.training else 0.0
         # Flash Attention forward pass
         attn_output = flash_attention_forward(
             query_layer,
             key_layer,
             value_layer,
             attention_mask,
-            self.dropout.p,
+            attn_dropout,
             softmax_scale=None,
             is_causal=False,
         )
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.dropout(attn_output)
         # The BertLayer expects a tuple
         return (attn_output,)

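For reference, a minimal standalone sketch of the fused QKV path this diff introduces: a single Conv1D(3 * embed_dim, embed_dim) matmul replaces the three separate nn.Linear query/key/value projections, and .split(hidden_size, dim=2) recovers the three per-tensor views. This is not code from the commit; the hidden_size and num_heads values and the (batch, seq_len, n_heads, head_dim) layout assumed for split_heads are inferred from how flash_attention_forward unpacks query_states.shape in the hunks above.

import torch
from transformers.pytorch_utils import Conv1D

# Assumed toy dimensions for illustration only.
hidden_size, num_heads = 768, 12
head_dim = hidden_size // num_heads

c_attn = Conv1D(3 * hidden_size, hidden_size)     # fused Q/K/V projection
hidden_states = torch.randn(2, 16, hidden_size)   # (batch, seq_len, hidden)

# One matmul yields (batch, seq_len, 3 * hidden); splitting along the last
# dimension recovers query, key and value, each (batch, seq_len, hidden).
query, key, value = c_attn(hidden_states).split(hidden_size, dim=2)

# Reshape to the (batch, seq_len, n_heads, head_dim) layout that
# flash_attention_forward expects (assumed to be what split_heads does).
query = query.view(2, 16, num_heads, head_dim)
key = key.view(2, 16, num_heads, head_dim)
value = value.view(2, 16, num_heads, head_dim)
print(query.shape, key.shape, value.shape)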