
Commit f6e5bd1

committed: updated the logic for splitting heads
1 parent c9b28f0

File tree

1 file changed: +14 / -24 lines


src/cehrbert/models/hf_models/hf_cehrbert.py

Lines changed: 14 additions & 24 deletions
@@ -26,7 +26,6 @@ def flash_attention_forward(
     key_states,
     value_states,
     attention_mask,
-    query_length,
     dropout=0.0,
     softmax_scale=None,
     is_causal=False,
@@ -51,18 +50,12 @@ def flash_attention_forward(
             The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         is_causal (`bool`, *optional*):
     """
-
-    # Flash attention requires the input to have the shape
-    # batch_size x seq_length x head_dim x hidden_dim
-    # therefore we just need to keep the original shape
-    dtype = query_states.dtype
-    query_states = query_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-    key_states = key_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-    value_states = value_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-
+    batch_size, query_length = query_states.shape[:2]
+    query_states = query_states.to(torch.bfloat16)
+    key_states = key_states.to(torch.bfloat16)
+    value_states = value_states.to(torch.bfloat16)
     # Contains at least one padding token in the sequence
     if attention_mask is not None:
-        batch_size = query_states.shape[0]
         (
             query_states,
             key_states,
@@ -98,8 +91,7 @@ def flash_attention_forward(
         softmax_scale=softmax_scale,
         causal=is_causal,
     )
-    # re-order the tensor back to (batch, n_heads, seq_length, head_dim)
-    return attn_output.permute(0, 2, 1, 3).contiguous().to(dtype)
+    return attn_output.reshape(batch_size, query_length, -1)


 # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
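Taken together, the three hunks above change the shape contract of flash_attention_forward: query_length is no longer passed in but derived from the query tensor, the permute into (batch, n_heads, seq_len, head_dim) is dropped, and the output is flattened to (batch, seq_len, hidden) before being returned. Below is a minimal, self-contained sketch of that contract; attention_shapes_sketch is a hypothetical name, and PyTorch's scaled_dot_product_attention stands in for the flash-attn kernel, which consumes the split layout directly without the transposes shown here.

import torch
import torch.nn.functional as F


def attention_shapes_sketch(query_states, key_states, value_states, dropout=0.0, is_causal=False):
    # Inputs arrive already split into heads: (batch, seq_len, n_heads, head_dim)
    batch_size, query_length = query_states.shape[:2]
    # SDPA wants (batch, n_heads, seq_len, head_dim), so transpose in and back out;
    # the real flash-attn kernel accepts the (batch, seq_len, n_heads, head_dim) layout as-is.
    out = F.scaled_dot_product_attention(
        query_states.transpose(1, 2),
        key_states.transpose(1, 2),
        value_states.transpose(1, 2),
        dropout_p=dropout,
        is_causal=is_causal,
    ).transpose(1, 2)
    # Merge heads: (batch, seq_len, n_heads, head_dim) -> (batch, seq_len, hidden)
    return out.reshape(batch_size, query_length, -1)


q = k = v = torch.randn(2, 16, 4, 8)  # (batch, seq_len, n_heads, head_dim)
print(attention_shapes_sketch(q, k, v).shape)  # torch.Size([2, 16, 32])

The sketch only illustrates shapes; it omits the bfloat16 cast and the unpad/pad path the real function takes when attention_mask contains padding tokens.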
@@ -176,10 +168,9 @@ def __init__(self, config, position_embedding_type=None):

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
+        return x.view(new_x_shape)

     def forward(
         self,
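The renamed helper also changes its output layout: transpose_for_scores produced (batch, n_heads, seq_len, head_dim) via a view plus permute, while split_heads stops after the view and leaves the tensor as (batch, seq_len, n_heads, head_dim). A small standalone illustration of the two layouts, using made-up head counts rather than the model's actual config:

import torch

num_attention_heads, attention_head_size = 4, 8
x = torch.randn(2, 16, num_attention_heads * attention_head_size)  # (batch, seq_len, hidden)

new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
old_layout = x.view(new_x_shape).permute(0, 2, 1, 3)  # transpose_for_scores: (batch, n_heads, seq_len, head_dim)
new_layout = x.view(new_x_shape)                      # split_heads: (batch, seq_len, n_heads, head_dim)
print(old_layout.shape, new_layout.shape)  # torch.Size([2, 4, 16, 8]) torch.Size([2, 16, 4, 8])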
@@ -204,26 +195,25 @@ def forward(
             value_layer = past_key_value[1]
             attention_mask = encoder_attention_mask
         elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            key_layer = self.split_heads(self.key(encoder_hidden_states))
+            value_layer = self.split_heads(self.value(encoder_hidden_states))
             attention_mask = encoder_attention_mask
         elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = self.split_heads(self.key(hidden_states))
+            value_layer = self.split_heads(self.value(hidden_states))
             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = self.split_heads(self.key(hidden_states))
+            value_layer = self.split_heads(self.value(hidden_states))

-        query_layer = self.transpose_for_scores(mixed_query_layer)
+        query_layer = self.split_heads(mixed_query_layer)
         # Flash Attention forward pass
         attn_output = flash_attention_forward(
             query_layer,
             key_layer,
             value_layer,
             attention_mask,
-            query_layer.size(-2),
             self.dropout.p,
             softmax_scale=None,
             is_causal=False,
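Besides the rename at each call site, the forward pass stops passing query_layer.size(-2) as query_length, since flash_attention_forward now reads the length from the query tensor itself (and under the new layout, size(-2) would be the number of heads rather than the sequence length anyway). A hypothetical end-to-end sketch of how split_heads and the attention call compose after this commit; ToySelfAttention is not the class in this file, and scaled_dot_product_attention again stands in for the flash-attn kernel:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySelfAttention(nn.Module):
    def __init__(self, hidden_size=32, num_attention_heads=4):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, hidden) -> (batch, seq_len, n_heads, head_dim)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(new_x_shape)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        query_layer = self.split_heads(self.query(hidden_states))
        key_layer = self.split_heads(self.key(hidden_states))
        value_layer = self.split_heads(self.value(hidden_states))
        batch_size, query_length = query_layer.shape[:2]
        # Stand-in for flash_attention_forward: the transposes exist only because SDPA
        # wants (batch, n_heads, seq_len, head_dim); flash-attn takes the split layout directly.
        attn_output = F.scaled_dot_product_attention(
            query_layer.transpose(1, 2),
            key_layer.transpose(1, 2),
            value_layer.transpose(1, 2),
        ).transpose(1, 2)
        return attn_output.reshape(batch_size, query_length, -1)


layer = ToySelfAttention()
print(layer(torch.randn(2, 16, 32)).shape)  # torch.Size([2, 16, 32])

Keeping the heads axis after the sequence axis avoids the permute/contiguous round-trips on every layer, since flash-attn's kernels already expect the (batch, seq_len, n_heads, head_dim) layout.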
