Commit 8e88b00

add fix for output_token length check
1 parent: 602d2d2
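
This change moves the in-place update of the cached output_tokens tensor out of the prompt-cache reuse branch into its own branch guarded by a length check: the cache is only updated in place when it is on the target device, output_tokens is non-empty, and the cached rows have non-zero length. The unconditional device transfer of output_t before the return is dropped; output_tokens is now transferred, only if needed, when the result object is constructed.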

1 file changed: +14 −11

vllm/model_executor/sampling_metadata.py

Lines changed: 14 additions & 11 deletions

@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from typing import Optional
 
-import torch,time
+import torch
 
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams, SamplingType
@@ -571,14 +571,6 @@ def from_lists(
                     prompt_tokens_cache.device == device):
                 # Reuse cached prompt_tokens already on HPU
                 prompt_t = prompt_tokens_cache
-                # Get the last element from each list
-                last_elements = [out[-1] for out in output_tokens]
-                lengths = [len(out)-1 for out in output_tokens]
-                indices = torch.tensor(lengths, device=device)
-                rows = torch.arange(output_tokens_cache.shape[0], device=device)
-                # Convert to a PyTorch tensor with shape [4, 1]
-                last_elements_t = torch.tensor(last_elements).unsqueeze(1).to(output_tokens_cache.device)
-                output_t = output_tokens_cache.index_put_((rows, indices), last_elements_t)
             else:
                 prompt_t = make_tensor_with_pad_align(
                     prompt_tokens,
@@ -588,6 +580,18 @@ def from_lists(
                     pin_memory=pin_memory,
                     max_len_align=1024,
                 )
+            if (output_tokens_cache is not None and
+                    output_tokens_cache.device == device and
+                    len(output_tokens) > 0 and len(output_tokens_cache[0]) > 0):
+                # Get the last element from each list
+                last_elements = [out[-1] for out in output_tokens]
+                lengths = [len(out)-1 for out in output_tokens]
+                indices = torch.tensor(lengths, device=device)
+                rows = torch.arange(output_tokens_cache.shape[0], device=device)
+                # Convert to a PyTorch tensor with shape [4, 1]
+                last_elements_t = torch.tensor(last_elements).unsqueeze(1).to(output_tokens_cache.device)
+                output_t = output_tokens_cache.index_put_((rows, indices), last_elements_t)
+            else:
                 output_t = make_tensor_with_pad_align(
                     output_tokens,
                     vocab_size,
@@ -660,7 +664,6 @@ def from_lists(
             )
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
-        output_t=output_t.to(device=device, non_blocking=True) if output_t.device != device else output_t
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -673,5 +676,5 @@ def from_lists(
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
             prompt_tokens=prompt_t.to(device=device, non_blocking=True) if prompt_t.device != device else prompt_t,
-            output_tokens=output_t
+            output_tokens=output_t.to(device=device, non_blocking=True) if output_t.device != device else output_t
         )
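
For readers unfamiliar with the index_put_ pattern above, here is a minimal, self-contained sketch of the guarded in-place update. It is illustrative only, not vLLM code: the helper name update_output_cache and the [batch, max_len] cache layout are assumptions. One deliberate difference from the diff: the values tensor is kept one-dimensional ([batch]) so that stock PyTorch broadcasting accepts it against the [batch]-shaped indexing result.

import torch

def update_output_cache(cache: torch.Tensor,
                        output_tokens: list[list[int]]) -> torch.Tensor:
    """Hypothetical helper: write each request's newest output token
    into its row of a padded [batch, max_len] token cache, in place."""
    device = cache.device
    # The commit's new length check: only touch the cache when there is
    # at least one request and the cached rows are non-empty.
    if len(output_tokens) == 0 or cache.shape[1] == 0:
        raise ValueError("cache not reusable; rebuild a padded tensor")
    # One row index per request: 0, 1, ..., batch-1.
    rows = torch.arange(cache.shape[0], device=device)
    # Column of the newest token in each row: len(out) - 1.
    cols = torch.tensor([len(out) - 1 for out in output_tokens],
                        device=device)
    # Newest token per request; kept 1-D so it broadcasts to the
    # [batch]-shaped result of indexing cache[rows, cols].
    vals = torch.tensor([out[-1] for out in output_tokens],
                        device=device, dtype=cache.dtype)
    # In-place scatter: cache[rows[i], cols[i]] = vals[i].
    return cache.index_put_((rows, cols), vals)

# Example: a 4-request cache padded to length 8.
cache = torch.zeros(4, 8, dtype=torch.long)
update_output_cache(cache, [[5], [7, 9], [2, 4, 6], [1, 3]])
# cache[0, 0] == 5, cache[1, 1] == 9, cache[2, 2] == 6, cache[3, 1] == 3

Building rows, cols, and vals directly on the cache's device avoids the extra .to(output_tokens_cache.device) hop seen in the diff; whether that matters depends on where the token lists live, which the sketch does not model.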
