This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit 0af3abe

[TPU][Bugfix] Fix next_token_ids shape (vllm-project#8128)

1 parent f1575dc commit 0af3abe

File tree: 1 file changed, +4 −1 lines changed

vllm/worker/tpu_model_runner.py (4 additions, 1 deletion)
@@ -601,7 +601,7 @@ def _execute_model(*args):
                     batch_idx += 1
             else:
                 for seq_id in seq_ids:
-                    next_token_id = next_token_ids[batch_idx][0]
+                    next_token_id = next_token_ids[batch_idx]
                     seq_outputs.append(
                         SequenceOutput(seq_id, next_token_id,
                                        {next_token_id: zero_logprob}))

@@ -722,6 +722,9 @@ def forward(
         sampled_token_ids = torch.multinomial(probs,
                                               num_samples,
                                               replacement=True)
+        if num_samples == 1:
+            argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
+            sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
         next_token_ids = torch.where(t != 0, sampled_token_ids,
                                      argmax_token_ids)
         return next_token_ids
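A minimal sketch of the shape behavior this patch addresses: `torch.multinomial` always returns a tensor of shape `[batch, num_samples]`, so with `num_samples == 1` the result carries a trailing singleton dimension. Squeezing it away yields a 1-D `[batch]` tensor, so indexing with `next_token_ids[batch_idx]` produces a scalar token id directly, matching the `[0]` subscript removed in the first hunk. (The tensor names and values here are illustrative, not taken from vLLM.)

```python
import torch

# Toy per-sequence probability distributions over a 2-token vocab.
probs = torch.tensor([[0.1, 0.9],
                      [0.8, 0.2]])

# multinomial keeps a trailing num_samples dimension even when it is 1.
sampled = torch.multinomial(probs, num_samples=1, replacement=True)
print(sampled.shape)  # torch.Size([2, 1])

# The patch squeezes that dimension when num_samples == 1 so the sampled
# ids line up with the argmax branch inside torch.where.
sampled = sampled.squeeze(dim=-1)
print(sampled.shape)  # torch.Size([2])

# After the squeeze, indexing one batch entry gives a 0-D scalar,
# which is why next_token_ids[batch_idx] no longer needs [0].
print(sampled[0].dim())  # 0
```

The same squeeze is applied to `argmax_token_ids` in the patch, which suggests both branches of the `torch.where` carried the trailing singleton dimension before the fix.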

0 commit comments