
Commit 6054686

updated the logic for splitting heads
1 parent c9b28f0 · commit 6054686

File tree: 1 file changed (+14, -22 lines)


src/cehrbert/models/hf_models/hf_cehrbert.py

Lines changed: 14 additions & 22 deletions
@@ -51,18 +51,12 @@ def flash_attention_forward(
             The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         is_causal (`bool`, *optional*):
     """
-
-    # Flash attention requires the input to have the shape
-    # batch_size x seq_length x head_dim x hidden_dim
-    # therefore we just need to keep the original shape
-    dtype = query_states.dtype
-    query_states = query_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-    key_states = key_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-    value_states = value_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-
+    batch_size = query_states.shape[0]
+    query_states = query_states.to(torch.bfloat16)
+    key_states = key_states.to(torch.bfloat16)
+    value_states = value_states.to(torch.bfloat16)
     # Contains at least one padding token in the sequence
     if attention_mask is not None:
-        batch_size = query_states.shape[0]
         (
             query_states,
             key_states,
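
A minimal standalone sketch of the tensor layout this hunk relies on (sizes and variable names here are illustrative, not taken from the repo): the flash-attention kernels consume inputs shaped (batch, seq_len, n_heads, head_dim) in fp16/bf16, so once the heads are split without a transpose, only the cast to bfloat16 is still needed.

import torch

batch, seq_len, n_heads, head_dim = 2, 128, 12, 64  # illustrative sizes
hidden = n_heads * head_dim

q = torch.randn(batch, seq_len, hidden)
q = q.view(batch, seq_len, n_heads, head_dim)  # split heads; the sequence axis stays where it is
q = q.to(torch.bfloat16)                       # flash-attn kernels expect fp16/bf16 inputs
assert q.shape == (batch, seq_len, n_heads, head_dim)
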
@@ -98,8 +92,7 @@ def flash_attention_forward(
             softmax_scale=softmax_scale,
             causal=is_causal,
         )
-    # re-order the tensor back to (batch, n_heads, seq_length, head_dim)
-    return attn_output.permute(0, 2, 1, 3).contiguous().to(dtype)
+    return attn_output.reshape(batch_size, query_length, -1)


# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
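
A short sketch of the new return path (shapes assumed for illustration, not read from the repo): with the (batch, query_length, n_heads, head_dim) output of flash attention, a single reshape over the last two axes merges the heads back into the hidden dimension, which is all the changed return statement does.

import torch

batch, query_length, n_heads, head_dim = 2, 128, 12, 64  # illustrative sizes
attn_output = torch.randn(batch, query_length, n_heads, head_dim, dtype=torch.bfloat16)
merged = attn_output.reshape(batch, query_length, -1)  # merge heads: (batch, query_length, n_heads * head_dim)
assert merged.shape == (batch, query_length, n_heads * head_dim)
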
@@ -176,10 +169,9 @@ def __init__(self, config, position_embedding_type=None):

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
+        return x.view(new_x_shape)

     def forward(
         self,
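
To make the before/after of the rename explicit, here is a hypothetical free-function version of the new helper (standalone, for illustration only): split_heads reshapes the hidden dimension into (n_heads, head_dim) but, unlike the old transpose_for_scores, no longer permutes the head axis in front of the sequence axis.

import torch

def split_heads(x: torch.Tensor, num_attention_heads: int, attention_head_size: int) -> torch.Tensor:
    # (batch, seq_len, hidden) -> (batch, seq_len, n_heads, head_dim), no permute
    new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
    return x.view(new_x_shape)

x = torch.randn(2, 128, 12 * 64)
print(split_heads(x, 12, 64).shape)  # torch.Size([2, 128, 12, 64])
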
@@ -204,19 +196,19 @@ def forward(
             value_layer = past_key_value[1]
             attention_mask = encoder_attention_mask
         elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            key_layer = self.split_heads(self.key(encoder_hidden_states))
+            value_layer = self.split_heads(self.value(encoder_hidden_states))
             attention_mask = encoder_attention_mask
         elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = self.split_heads(self.key(hidden_states))
+            value_layer = self.split_heads(self.value(hidden_states))
             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = self.split_heads(self.key(hidden_states))
+            value_layer = self.split_heads(self.value(hidden_states))

-        query_layer = self.transpose_for_scores(mixed_query_layer)
+        query_layer = self.split_heads(mixed_query_layer)
         # Flash Attention forward pass
         attn_output = flash_attention_forward(
             query_layer,
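
As an end-to-end shape check for the call sites above (plain softmax attention stands in for flash_attention_forward here; all names and sizes are illustrative): the q/k/v projections are split into heads without a transpose, attention is computed on the (batch, seq, heads, head_dim) layout, and the result reshapes back to (batch, seq, hidden).

import torch

batch, seq_len, n_heads, head_dim = 2, 16, 4, 8  # illustrative sizes
hidden = n_heads * head_dim

def split_heads(x: torch.Tensor) -> torch.Tensor:
    return x.view(x.size()[:-1] + (n_heads, head_dim))

q = split_heads(torch.randn(batch, seq_len, hidden))
k = split_heads(torch.randn(batch, seq_len, hidden))
v = split_heads(torch.randn(batch, seq_len, hidden))

# Reference attention on the (batch, seq, heads, head_dim) layout; a stand-in for
# flash_attention_forward used here only to verify the shape contract.
scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / head_dim**0.5
weights = torch.softmax(scores, dim=-1)
out = torch.einsum("bhqk,bkhd->bqhd", weights, v).reshape(batch, seq_len, -1)
assert out.shape == (batch, seq_len, hidden)
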
