
Commit a993f05

Fix outputs not matching the non-flash case in MQA
Parent: 8c1889e

1 file changed: +1, -1


megatron/model/transformer.py

Lines changed: 1 addition & 1 deletion
@@ -414,7 +414,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
         # alibi: (batch_size * num_attention_heads, 1, max_seq_len)
         # TODO: ideally, alibi would have the shape: (1, num_heads * sq, sk)
         matmul_input_buffer = alibi[:bs * np, :, :sk].view(bs, np, sk)
-        matmul_input_buffer = matmul_input_buffer.repeat(1, sq, 1)  # [b, np * sq, sk]
+        matmul_input_buffer = matmul_input_buffer.unsqueeze(2).expand(bs, np, sq, sk).reshape(bs, np * sq, sk)  # [b, np * sq, sk]

         if alibi is None:
             # Raw attention scores. [b, np * sq, sk]
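
For illustration only (not part of the commit), here is a minimal sketch of why the two expressions lay out the buffer differently. The sizes bs, np, sq, sk are made-up example values, and np here is the number of attention heads as in the diff, not NumPy:

import torch

bs, np, sq, sk = 1, 2, 3, 4  # example sizes: batch, heads, query len, key len
alibi = torch.arange(bs * np * sk, dtype=torch.float32).view(bs, np, sk)

# Old layout: tiles the whole head dimension sq times, so rows alternate
# h0, h1, h0, h1, ... rather than grouping the sq rows of each head together.
old = alibi.repeat(1, sq, 1)                                    # [bs, np*sq, sk]

# New layout: each head's bias row is repeated sq times as a contiguous block,
# matching the [b, np * sq, sk] layout the non-flash attention path expects.
new = alibi.unsqueeze(2).expand(bs, np, sq, sk).reshape(bs, np * sq, sk)

print(torch.equal(old, new))  # False: same values, different row ordering
print(old[0, :4, 0])          # tensor([0., 4., 0., 4.])  -> heads interleaved
print(new[0, :4, 0])          # tensor([0., 0., 0., 4.])  -> head 0 block first

The difference is purely in row ordering, not memory: expand produces a view, but the final reshape still materializes a full [bs, np * sq, sk] buffer, just as repeat does.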
