Commit c82a5a1

Merge pull request #71 from mayank31398/mqa
Fixed MQA outputs not matching the HF model in the non-flash case
2 parents: 2223891 + 1809fc1 · commit: c82a5a1

File tree

1 file changed (+1, −1)


megatron/model/transformer.py

Lines changed: 1 addition & 1 deletion
@@ -414,7 +414,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
         # alibi: (batch_size * num_attention_heads, 1, max_seq_len)
         # TODO: ideally, alibi would have the shape: (1, num_heads * sq, sk)
         matmul_input_buffer = alibi[:bs * np, :, :sk].view(bs, np, sk)
-        matmul_input_buffer = matmul_input_buffer.repeat(1, sq, 1) # [b, np * sq, sk]
+        matmul_input_buffer = matmul_input_buffer.unsqueeze(2).expand(bs, np, sq, sk).reshape(bs, np * sq, sk) # [b, np * sq, sk]
 
         if alibi is None:
             # Raw attention scores. [b, np * sq, sk]
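
For context, a minimal standalone sketch (not part of the repository) of why the two expansions disagree: both calls produce the shape [b, np * sq, sk], but `repeat(1, sq, 1)` interleaves the heads along the row dimension, whereas `unsqueeze(2).expand(...).reshape(...)` keeps each head's sq rows contiguous, which presumably matches the head-major layout the rest of the non-flash attention path expects. The tensor sizes and the toy alibi values below are made up for illustration.

```python
import torch

# Tiny, made-up sizes: bs = batch, np = attention heads (as in the diff, not numpy),
# sq = query length, sk = key length.
bs, np, sq, sk = 1, 2, 3, 4

# Stand-in for the per-head alibi slice of shape [bs, np, sk]; each head gets a
# distinct constant value so the row ordering is visible below.
alibi = torch.arange(np, dtype=torch.float32).view(1, np, 1).expand(bs, np, sk)

# Old code: tiles the head dimension sq times, so rows come out interleaved as
# head0, head1, head0, head1, ...
old = alibi.repeat(1, sq, 1)

# New code: repeats each head sq times before flattening, so rows come out as
# head0 x sq, then head1 x sq (each head's rows contiguous).
new = alibi.unsqueeze(2).expand(bs, np, sq, sk).reshape(bs, np * sq, sk)

print(old[0, :, 0])  # tensor([0., 1., 0., 1., 0., 1.])
print(new[0, :, 0])  # tensor([0., 0., 0., 1., 1., 1.])
```

With ALiBi biases that differ across heads, the interleaved ordering pairs the wrong bias with most query rows, which is consistent with the mismatch against the HF implementation that this commit reports fixing.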
