 from dataclasses import dataclass
 from typing import Optional
 
-import torch, time
+import torch
 
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams, SamplingType
@@ -571,14 +571,6 @@ def from_lists(
                 prompt_tokens_cache.device == device):
             # Reuse cached prompt_tokens already on HPU
             prompt_t = prompt_tokens_cache
-            # Get the last element from each list
-            last_elements = [out[-1] for out in output_tokens]
-            lengths = [len(out) - 1 for out in output_tokens]
-            indices = torch.tensor(lengths, device=device)
-            rows = torch.arange(output_tokens_cache.shape[0], device=device)
-            # Convert to a PyTorch tensor with shape [4, 1]
-            last_elements_t = torch.tensor(last_elements).unsqueeze(1).to(output_tokens_cache.device)
-            output_t = output_tokens_cache.index_put_((rows, indices), last_elements_t)
         else:
             prompt_t = make_tensor_with_pad_align(
                 prompt_tokens,
@@ -588,6 +580,18 @@ def from_lists(
                 pin_memory=pin_memory,
                 max_len_align=1024,
             )
+        if (output_tokens_cache is not None and
+                output_tokens_cache.device == device and
+                len(output_tokens) > 0 and len(output_tokens_cache[0]) > 0):
+            # Get the last element from each list
+            last_elements = [out[-1] for out in output_tokens]
+            lengths = [len(out) - 1 for out in output_tokens]
+            indices = torch.tensor(lengths, device=device)
+            rows = torch.arange(output_tokens_cache.shape[0], device=device)
+            # Convert to a [batch_size, 1] tensor on the cache's device
+            last_elements_t = torch.tensor(last_elements).unsqueeze(1).to(output_tokens_cache.device)
+            output_t = output_tokens_cache.index_put_((rows, indices), last_elements_t)
+        else:
             output_t = make_tensor_with_pad_align(
                 output_tokens,
                 vocab_size,
@@ -660,7 +664,6 @@ def from_lists(
         )
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
-        output_t = output_t.to(device=device, non_blocking=True) if output_t.device != device else output_t
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -673,5 +676,5 @@ def from_lists(
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
             prompt_tokens=prompt_t.to(device=device, non_blocking=True) if prompt_t.device != device else prompt_t,
-            output_tokens=output_t
+            output_tokens=output_t.to(device=device, non_blocking=True) if output_t.device != device else output_t
         )
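
For context, here is a minimal, self-contained sketch (not vLLM code) of the in-place update the new `output_tokens_cache` branch performs: rather than rebuilding and re-padding the whole output-token tensor every step, only the newest token of each row is scattered into the cached padded tensor with `index_put_`. The batch size, pad value, CPU device, and the pre-filled cache below are illustrative assumptions, and the sketch uses a flat `[batch]` value vector instead of the `[batch, 1]` unsqueezed form used in the diff.

```python
import torch

# Illustrative sizes/values (assumptions, not taken from the PR).
vocab_size = 32000                 # pad value used by make_tensor_with_pad_align
device = torch.device("cpu")       # the real path keeps the cache on HPU

# Per-request output token lists after the latest decode step.
output_tokens = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]

# Cached padded tensor from the previous step: every token except the newest
# one of each row is already present; remaining slots hold the pad value.
output_tokens_cache = torch.full((len(output_tokens), 8), vocab_size,
                                 dtype=torch.int64, device=device)
for row, toks in enumerate(output_tokens):
    output_tokens_cache[row, :len(toks) - 1] = torch.tensor(toks[:-1], dtype=torch.int64)

# Same idea as the diff: write each row's newest token at column len(out) - 1.
last_elements = torch.tensor([out[-1] for out in output_tokens], device=device)
indices = torch.tensor([len(out) - 1 for out in output_tokens], device=device)
rows = torch.arange(output_tokens_cache.shape[0], device=device)
output_t = output_tokens_cache.index_put_((rows, indices), last_elements)

assert output_t is output_tokens_cache      # updated in place, no new allocation
assert output_t[1, 1].item() == 22          # newest token of row 1 landed at column 1
```

Because `index_put_` mutates a cached tensor that already lives on the target device, the `output_t.to(device=device, non_blocking=True) if output_t.device != device else output_t` guard moved into `return cls(...)` skips the host-to-device copy on this path and only transfers when `output_t` was freshly built on CPU.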