@@ -51,18 +51,12 @@ def flash_attention_forward(
             The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         is_causal (`bool`, *optional*):
     """
-
-    # Flash attention requires the input to have the shape
-    # batch_size x seq_length x head_dim x hidden_dim
-    # therefore we just need to keep the original shape
-    dtype = query_states.dtype
-    query_states = query_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-    key_states = key_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-    value_states = value_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-
+    batch_size = query_states.shape[0]
+    query_states = query_states.to(torch.bfloat16)
+    key_states = key_states.to(torch.bfloat16)
+    value_states = value_states.to(torch.bfloat16)
     # Contains at least one padding token in the sequence
     if attention_mask is not None:
-        batch_size = query_states.shape[0]
         (
             query_states,
             key_states,
@@ -98,8 +92,7 @@ def flash_attention_forward(
             softmax_scale=softmax_scale,
             causal=is_causal,
         )
-    # re-order the tensor back to (batch, n_heads, seq_length, head_dim)
-    return attn_output.permute(0, 2, 1, 3).contiguous().to(dtype)
+    return attn_output.reshape(batch_size, query_length, -1)


 # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
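
Aside, not part of the commit: a minimal sketch of the layout assumption behind dropping the permutes above. It assumes the flash-attn package is installed, a CUDA device is available, and that the kernel on the no-padding path is the public flash_attn_func API; the tensor names and sizes are illustrative only.

import torch
from flash_attn import flash_attn_func  # assumed kernel for the no-padding path

batch_size, seq_len, num_heads, head_dim = 2, 16, 12, 64

# Once the heads are split without a transpose, the q/k/v projections are already
# laid out as (batch, seq_len, num_heads, head_dim), which is the layout the
# flash-attention kernels consume, so no (0, 2, 1, 3) permute is needed.
q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)  # output keeps the same layout as q

# Merge the heads back into the hidden dimension, mirroring the new return statement.
out = out.reshape(batch_size, seq_len, -1)  # (batch, seq_len, num_heads * head_dim)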
@@ -176,10 +169,9 @@ def __init__(self, config, position_embedding_type=None):

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
+        return x.view(new_x_shape)

     def forward(
         self,
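
Aside, not part of the commit: a torch-only sketch contrasting the shapes produced by the old transpose_for_scores and the new split_heads; the sizes below are illustrative.

import torch

batch_size, seq_len, num_heads, head_dim = 2, 16, 12, 64
hidden_size = num_heads * head_dim
x = torch.randn(batch_size, seq_len, hidden_size)  # output of a q/k/v projection

# Old helper: (batch, seq_len, hidden) -> (batch, num_heads, seq_len, head_dim),
# the layout the eager matmul/softmax attention path expects.
old = x.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)

# New helper: (batch, seq_len, hidden) -> (batch, seq_len, num_heads, head_dim),
# which the flash-attention kernels consume without a transpose.
new = x.view(batch_size, seq_len, num_heads, head_dim)

assert old.shape == (batch_size, num_heads, seq_len, head_dim)
assert new.shape == (batch_size, seq_len, num_heads, head_dim)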
@@ -204,19 +196,19 @@ def forward(
             value_layer = past_key_value[1]
             attention_mask = encoder_attention_mask
         elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            key_layer = self.split_heads(self.key(encoder_hidden_states))
+            value_layer = self.split_heads(self.value(encoder_hidden_states))
             attention_mask = encoder_attention_mask
         elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = self.split_heads(self.key(hidden_states))
+            value_layer = self.split_heads(self.value(hidden_states))
             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = self.split_heads(self.key(hidden_states))
+            value_layer = self.split_heads(self.value(hidden_states))

-        query_layer = self.transpose_for_scores(mixed_query_layer)
+        query_layer = self.split_heads(mixed_query_layer)
         # Flash Attention forward pass
         attn_output = flash_attention_forward(
             query_layer,