
Commit ee150b8

Fix MLA FP8 padding to simulate MTP2 accuracy
1 parent 03284a9 commit ee150b8

File tree

1 file changed: +2 −2 lines changed


vllm/attention/ops/rocm_aiter_mla.py

Lines changed: 2 additions & 2 deletions
@@ -63,7 +63,7 @@ def aiter_mla_decode_fwd(
         kv_scale = torch.ones([1], dtype=torch.float, device=kv_buffer.device)
         batch_size = q_fp8.shape[0]
         q_fp8_padded = torch.ones(batch_size * 2, q_fp8.shape[1], q_fp8.shape[2], dtype=q_fp8.dtype, device=q_fp8.device)
-        q_fp8_padded[::2] = q_fp8
+        q_fp8_padded[1::2] = q_fp8
         qo_indptr_padded = torch.arange(0, (batch_size + 1) * 2, 2, dtype=qo_indptr.dtype, device=qo_indptr.device)
         o_padded = torch.empty((o.shape[0] * 2, o.shape[1], o.shape[2]), dtype=o.dtype, device=o.device).fill_(-1)
         max_seqlen_q_new = 2
@@ -88,7 +88,7 @@ def aiter_mla_decode_fwd(
             q_scale=q_scale,
             kv_scale=kv_scale,
         )
-        o[:] = o_padded[::2]  # Extract every second element
+        o[:] = o_padded[1::2]  # Extract every second element
     else:
         mla_decode_fwd_dispatch(q,
                                 kv_buffer.view(-1, 1, 1, q.shape[-1]),
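
The two hunks are a matched pair: the slot the padded query is written to must be the same slot the padded output is read back from. Below is a minimal PyTorch sketch of that interleaved-slot pattern, not the kernel call itself; the shapes are hypothetical and o_padded is a stand-in for the kernel output. The real query presumably goes in the second (odd) slot of each simulated two-token group because, under causal masking, only the last token of the group attends over the full KV cache.

import torch

# Hypothetical shapes for illustration only (not from the commit).
batch_size, num_heads, head_dim = 4, 16, 64
q_fp8 = torch.randn(batch_size, num_heads, head_dim)  # stand-in for the FP8 queries

# Pad to two slots per sequence; the real query goes in the second
# (odd) slot of each group, matching the fixed q_fp8_padded[1::2] write.
q_fp8_padded = torch.ones(batch_size * 2, num_heads, head_dim, dtype=q_fp8.dtype)
q_fp8_padded[1::2] = q_fp8

# Two query tokens per sequence: indptr [0, 2, 4, ..., 2 * batch_size].
qo_indptr_padded = torch.arange(0, (batch_size + 1) * 2, 2)

# Stand-in for the kernel output; the real results land in the same
# odd slots, so extraction must also use [1::2]. Reading [::2] would
# return the filler rows instead, which is the bug this commit fixes.
o_padded = q_fp8_padded.clone()
o = o_padded[1::2]
assert torch.equal(o, q_fp8)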
