@@ -26,7 +26,6 @@ def flash_attention_forward(
     key_states,
     value_states,
     attention_mask,
-    query_length,
     dropout=0.0,
     softmax_scale=None,
     is_causal=False,
@@ -51,18 +50,12 @@ def flash_attention_forward(
             The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         is_causal (`bool`, *optional*):
     """
-
-    # Flash attention requires the input to have the shape
-    # batch_size x seq_length x head_dim x hidden_dim
-    # therefore we just need to keep the original shape
-    dtype = query_states.dtype
-    query_states = query_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-    key_states = key_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-    value_states = value_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-
+    batch_size, query_length = query_states.shape[:2]
+    query_states = query_states.to(torch.bfloat16)
+    key_states = key_states.to(torch.bfloat16)
+    value_states = value_states.to(torch.bfloat16)
     # Contains at least one padding token in the sequence
     if attention_mask is not None:
-        batch_size = query_states.shape[0]
         (
             query_states,
             key_states,
@@ -98,8 +91,7 @@ def flash_attention_forward(
             softmax_scale=softmax_scale,
             causal=is_causal,
         )
-    # re-order the tensor back to (batch, n_heads, seq_length, head_dim)
-    return attn_output.permute(0, 2, 1, 3).contiguous().to(dtype)
+    return attn_output.reshape(batch_size, query_length, -1)


 # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
@@ -176,10 +168,9 @@ def __init__(self, config, position_embedding_type=None):

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
+        return x.view(new_x_shape)

     def forward(
         self,
@@ -204,26 +195,25 @@ def forward(
             value_layer = past_key_value[1]
             attention_mask = encoder_attention_mask
         elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            key_layer = self.split_heads(self.key(encoder_hidden_states))
+            value_layer = self.split_heads(self.value(encoder_hidden_states))
             attention_mask = encoder_attention_mask
         elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = self.split_heads(self.key(hidden_states))
+            value_layer = self.split_heads(self.value(hidden_states))
             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = self.split_heads(self.key(hidden_states))
+            value_layer = self.split_heads(self.value(hidden_states))

-        query_layer = self.transpose_for_scores(mixed_query_layer)
+        query_layer = self.split_heads(mixed_query_layer)
         # Flash Attention forward pass
         attn_output = flash_attention_forward(
             query_layer,
             key_layer,
             value_layer,
             attention_mask,
-            query_layer.size(-2),
             self.dropout.p,
             softmax_scale=None,
             is_causal=False,
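
The change above keeps query, key, and value in the (batch, seq_len, num_heads, head_dim) layout that the flash-attn kernels consume directly, derives query_length from the query tensor instead of taking it as an argument, and flattens the heads back into the hidden dimension on return. A minimal smoke-test sketch of the new calling convention follows; it assumes flash_attention_forward from the patched module is importable, and the shapes, device, and dummy tensors are illustrative only, not part of the commit:

import torch

# Illustrative shapes only (not taken from the commit).
batch, seq_len, num_heads, head_dim = 2, 16, 12, 64

# Inputs are already in (batch, seq_len, num_heads, head_dim); no permute is needed.
query = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda")
key = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda")
value = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda")

# query_length is no longer passed; the function reads it from query_states.shape[:2].
out = flash_attention_forward(query, key, value, attention_mask=None, dropout=0.0)

# The output is reshaped to (batch, seq_len, num_heads * head_dim) before returning.
assert out.shape == (batch, seq_len, num_heads * head_dim)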