 from dataclasses import dataclass
 from typing import Optional
 
-import torch
+import torch, time
 
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams, SamplingType
@@ -423,6 +423,7 @@ def from_sampling_metadata(
         device: torch.device,
         dtype: torch.dtype,
         prompt_tokens_cache: torch.tensor,
+        output_tokens_cache: torch.tensor,
         past_seq_ids: set,
     ) -> tuple["SamplingTensors", bool, bool, bool, Optional[int],
                Optional[float], Optional[torch.tensor]]:
@@ -516,7 +517,7 @@ def from_sampling_metadata(
         current_seq_ids.update(seq_ids)
         if current_seq_ids != past_seq_ids:
             prompt_tokens_cache = None
-
+            output_tokens_cache = None
         top_k_scalar = top_ks[0] if do_top_p_top_k and all(
             k == top_ks[0] for k in top_ks) else None
         top_p_scalar = top_ps[0] if do_top_p_top_k and all(
@@ -536,6 +537,7 @@ def from_sampling_metadata(
             device,
             dtype,
             prompt_tokens_cache,
+            output_tokens_cache,
         )
         return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
                 top_k_scalar, top_p_scalar, current_seq_ids)
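For context, here is a minimal standalone sketch of the invalidation rule this hunk extends: whenever the set of live sequence ids changes between scheduler steps, both cached token tensors are dropped and rebuilt. The helper name and free-function form are illustrative assumptions, not vLLM's actual API.

```python
def maybe_invalidate_caches(current_seq_ids: set, past_seq_ids: set,
                            prompt_tokens_cache, output_tokens_cache):
    # Hypothetical helper; vLLM performs this check inline in
    # from_sampling_metadata.
    if current_seq_ids != past_seq_ids:
        # A sequence joined or left the batch, so the rows of the cached
        # padded tensors no longer correspond to the live sequences.
        prompt_tokens_cache = None
        output_tokens_cache = None
    return prompt_tokens_cache, output_tokens_cache
```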
@@ -556,6 +558,7 @@ def from_lists(
         device: torch.device,
         dtype: torch.dtype,
         prompt_tokens_cache: torch.tensor,
+        output_tokens_cache: torch.tensor,
     ) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
@@ -568,6 +571,14 @@ def from_lists(
                     prompt_tokens_cache.device == device):
                 # Reuse cached prompt_tokens already on HPU
                 prompt_t = prompt_tokens_cache
+                # Cache hit: each sequence gained exactly one token since the
+                # cached tensor was built, so scatter just those newest tokens
+                # instead of re-padding and re-uploading the whole batch.
+                last_elements = [out[-1] for out in output_tokens]
+                lengths = [len(out) - 1 for out in output_tokens]
+                indices = torch.tensor(lengths, device=device)
+                rows = torch.arange(output_tokens_cache.shape[0], device=device)
+                # 1-D values of shape [num_seqs], matching the indexing result.
+                last_elements_t = torch.tensor(last_elements).to(output_tokens_cache.device)
+                output_t = output_tokens_cache.index_put_((rows, indices), last_elements_t)
             else:
                 prompt_t = make_tensor_with_pad_align(
                     prompt_tokens,
@@ -577,14 +588,14 @@ def from_lists(
                     pin_memory=pin_memory,
                     max_len_align=1024,
                 )
-            output_t = make_tensor_with_pad_align(
-                output_tokens,
-                vocab_size,
-                device="cpu",
-                dtype=torch.int64,
-                pin_memory=pin_memory,
-                max_len_align=1024,
-            )
+                output_t = make_tensor_with_pad_align(
+                    output_tokens,
+                    vocab_size,
+                    device="cpu",
+                    dtype=torch.int64,
+                    pin_memory=pin_memory,
+                    max_len_align=1024,
+                )
         else:
             prompt_t = make_tensor_with_pad(
                 prompt_tokens,
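The cache-hit branch above is the core of the change: rather than re-padding the full output-token batch on the CPU and uploading it every decode step, it scatters only each sequence's newest token into the tensor already resident on the device. A self-contained sketch of that update, assuming exactly one token per sequence was appended since the cache was built and the padded width still has room for it:

```python
import torch

def scatter_latest_tokens(cache: torch.Tensor,
                          output_tokens: list[list[int]]) -> torch.Tensor:
    # cache: [num_seqs, padded_len] token ids, already on the target device.
    device = cache.device
    rows = torch.arange(len(output_tokens), device=device)
    # Column where each sequence's newest token belongs.
    cols = torch.tensor([len(out) - 1 for out in output_tokens], device=device)
    vals = torch.tensor([out[-1] for out in output_tokens],
                        dtype=cache.dtype, device=device)
    # Equivalent to cache[rows, cols] = vals, done as one in-place scatter.
    return cache.index_put_((rows, cols), vals)
```

Note that the value tensor must be 1-D with shape [num_seqs]: index_put_ with two 1-D index tensors produces an indexing result of shape [num_seqs], and a [num_seqs, 1] value tensor cannot be broadcast to it.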
@@ -649,7 +660,7 @@ def from_lists(
         )
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
-
+        output_t = (output_t.to(device=device, non_blocking=True)
+                    if output_t.device != device else output_t)
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -662,5 +673,5 @@ def from_lists(
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
             prompt_tokens=prompt_t.to(device=device, non_blocking=True) if prompt_t.device != device else prompt_t,
-            output_tokens=output_t.to(device=device, non_blocking=True),
+            output_tokens=output_t,
         )
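Finally, the transfer of output_t is now conditional because, on a cache hit, the tensor is already on the accelerator and an unconditional .to(device) would issue a redundant copy. A small sketch of the pattern, assuming the CPU-side tensor was allocated in pinned memory as the comment above notes:

```python
import torch

def to_device(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    # non_blocking=True lets the host-to-device copy overlap with other host
    # work when the source lives in pinned (page-locked) memory.
    return t if t.device == device else t.to(device=device, non_blocking=True)
```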