
Commit 602d2d2

yeonsily and libinta authored
Libint/add samplemetatensorcache3 (#1991)
Co-authored-by: Libin Tang <[email protected]>
1 parent 22128e5 commit 602d2d2

3 files changed, +33 -16 lines changed

vllm/model_executor/layers/sampler.py

Lines changed: 3 additions & 1 deletion
@@ -199,6 +199,7 @@ def __init__(self):
         self.should_modify_greedy_probs_inplace = False
         # Add HPU cache class variables
         self._prompt_tokens_hpu_cache: Optional[torch.Tensor] = None
+        self._output_tokens_hpu_cache: Optional[torch.Tensor] = None
         self._cached_seq_ids: Optional[set] = None
 
     def _init_sampling_tensors(
@@ -222,7 +223,7 @@ def _init_sampling_tensors(
          top_k_scalar, top_p_scalar, current_seq_ids) = \
             SamplingTensors.from_sampling_metadata(
                 sampling_metadata, vocab_size, logits.device, logits.dtype, \
-                self._prompt_tokens_hpu_cache, self._cached_seq_ids)
+                self._prompt_tokens_hpu_cache, self._output_tokens_hpu_cache, self._cached_seq_ids)
 
         self._sampling_tensors = sampling_tensors
         self._do_penalties = do_penalties
@@ -237,6 +238,7 @@ def _init_sampling_tensors(
         # After tensors are created, update cache
         if self._cached_seq_ids != current_seq_ids:
             self._prompt_tokens_hpu_cache = None
+            self._output_tokens_hpu_cache = None
         self._cached_seq_ids = current_seq_ids
 
     def forward(
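
The change above keeps the padded prompt and output token tensors cached on HPU between sampling steps and drops both caches whenever the set of running sequence ids changes. A minimal sketch of that invalidation rule, with illustrative names rather than the actual vLLM classes:

from typing import Optional, Set
import torch

class TokenCache:
    """Illustrative stand-in for the sampler-side HPU cache."""

    def __init__(self) -> None:
        self.prompt_tokens: Optional[torch.Tensor] = None
        self.output_tokens: Optional[torch.Tensor] = None
        self.seq_ids: Optional[Set[int]] = None

    def refresh(self, current_seq_ids: Set[int]) -> None:
        # Same rule as _init_sampling_tensors: any change in the running
        # sequence ids invalidates both cached tensors.
        if self.seq_ids != current_seq_ids:
            self.prompt_tokens = None
            self.output_tokens = None
        self.seq_ids = current_seq_ids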

vllm/model_executor/sampling_metadata.py

Lines changed: 23 additions & 12 deletions
@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from typing import Optional
 
-import torch
+import torch,time
 
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams, SamplingType
@@ -423,6 +423,7 @@ def from_sampling_metadata(
         device: torch.device,
         dtype: torch.dtype,
         prompt_tokens_cache: torch.tensor,
+        output_tokens_cache: torch.tensor,
         past_seq_ids: set,
     ) -> tuple["SamplingTensors", bool, bool, bool, Optional[int],
                Optional[float], Optional[torch.tensor]]:
@@ -516,7 +517,7 @@ def from_sampling_metadata(
         current_seq_ids.update(seq_ids)
         if current_seq_ids != past_seq_ids:
             prompt_tokens_cache = None
-
+            output_tokens_cache = None
         top_k_scalar = top_ks[0] if do_top_p_top_k and all(
             k == top_ks[0] for k in top_ks) else None
         top_p_scalar = top_ps[0] if do_top_p_top_k and all(
@@ -536,6 +537,7 @@ def from_sampling_metadata(
             device,
             dtype,
             prompt_tokens_cache,
+            output_tokens_cache,
         )
         return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
                 top_k_scalar, top_p_scalar, current_seq_ids)
@@ -556,6 +558,7 @@ def from_lists(
         device: torch.device,
         dtype: torch.dtype,
         prompt_tokens_cache: torch.tensor,
+        output_tokens_cache: torch.tensor,
     ) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
@@ -568,6 +571,14 @@ def from_lists(
                     prompt_tokens_cache.device == device):
                 # Reuse cached prompt_tokens already on HPU
                 prompt_t = prompt_tokens_cache
+                # Get the last element from each list
+                last_elements = [out[-1] for out in output_tokens]
+                lengths = [len(out)-1 for out in output_tokens]
+                indices = torch.tensor(lengths, device=device)
+                rows = torch.arange(output_tokens_cache.shape[0], device=device)
+                # Convert to a PyTorch tensor with shape [4, 1]
+                last_elements_t = torch.tensor(last_elements).unsqueeze(1).to(output_tokens_cache.device)
+                output_t = output_tokens_cache.index_put_((rows, indices), last_elements_t)
             else:
                 prompt_t = make_tensor_with_pad_align(
                     prompt_tokens,
@@ -577,14 +588,14 @@ def from_lists(
                     pin_memory=pin_memory,
                     max_len_align=1024,
                 )
-            output_t = make_tensor_with_pad_align(
-                output_tokens,
-                vocab_size,
-                device="cpu",
-                dtype=torch.int64,
-                pin_memory=pin_memory,
-                max_len_align=1024,
-            )
+                output_t = make_tensor_with_pad_align(
+                    output_tokens,
+                    vocab_size,
+                    device="cpu",
+                    dtype=torch.int64,
+                    pin_memory=pin_memory,
+                    max_len_align=1024,
+                )
         else:
             prompt_t = make_tensor_with_pad(
                 prompt_tokens,
@@ -649,7 +660,7 @@ def from_lists(
             )
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
-
+        output_t=output_t.to(device=device, non_blocking=True) if output_t.device != device else output_t
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -662,5 +673,5 @@ def from_lists(
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
             prompt_tokens=prompt_t.to(device=device, non_blocking=True) if prompt_t.device != device else prompt_t,
-            output_tokens=output_t.to(device=device, non_blocking=True),
+            output_tokens=output_t
         )
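
The decode fast path above is the heart of the change: when the cached tensors can be reused, only the newest sampled token of each sequence is written into the cached padded output_tokens tensor with index_put_, instead of rebuilding the padded tensor on CPU and copying it to HPU every step. A minimal CPU sketch of that update pattern (shapes, pad value, and the 1-D value tensor are illustrative, not taken from the commit):

import torch

# Per-sequence output token lists; the last entry of each is the newest sample.
output_tokens = [[11, 12], [21, 22], [31, 32, 33]]
pad = 32000                                               # vLLM pads with vocab_size; value here is made up
max_len = max(len(t) for t in output_tokens)

# Cached padded tensor holding everything except the newest token of each row.
cache = torch.full((len(output_tokens), max_len), pad, dtype=torch.int64)
for i, toks in enumerate(output_tokens):
    cache[i, :len(toks) - 1] = torch.tensor(toks[:-1], dtype=torch.int64)

rows = torch.arange(cache.shape[0])                       # one row per sequence
cols = torch.tensor([len(t) - 1 for t in output_tokens])  # first still-padded slot per row
newest = torch.tensor([t[-1] for t in output_tokens])     # kept 1-D to match the indexed shape
cache.index_put_((rows, cols), newest)                    # in-place scatter of the new tokens
print(cache)
# tensor([[   11,    12, 32000],
#         [   21,    22, 32000],
#         [   31,    32,    33]])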

vllm/worker/hpu_model_runner.py

Lines changed: 7 additions & 3 deletions
@@ -1888,12 +1888,15 @@ def _prepare_prompt(
             if image_index_tensor is not None:
                 multi_modal_kwargs['image_index'] = image_index_tensor
 
-            use_mediapipe = os.getenv("VLLM_USE_MEDIA_PIPELINE", "false").lower() in ("1", "true", "yes")
+            use_mediapipe = os.getenv("VLLM_USE_MEDIA_PIPELINE",
+                                      "false").lower() in ("1", "true", "yes")
             if use_mediapipe:
                 # With mediapipe path some tensors will already be on HPU, we only move to HPU if needed
                 for key in multi_modal_kwargs.keys():
-                    if hasattr(multi_modal_kwargs[key], "device") and multi_modal_kwargs[key].device != self.device:
-                        multi_modal_kwargs[key] = self.move_to_device(multi_modal_kwargs[key])
+                    if hasattr(multi_modal_kwargs[key], "device"
+                               ) and multi_modal_kwargs[key].device != self.device:
+                        multi_modal_kwargs[key] = self.move_to_device(
+                            multi_modal_kwargs[key])
             else:
                 multi_modal_kwargs = MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                                                 device=self.device)
@@ -4030,6 +4033,7 @@ def try_revert_dummy_output_tokens():
                 if sampling_tensors.prompt_tokens.numel() > 0:
                     # Cache the prompt_tokens tensor that's already on HPU
                     self.model.sampler._prompt_tokens_hpu_cache = sampling_tensors.prompt_tokens
+                    self.model.sampler._output_tokens_hpu_cache = sampling_tensors.output_tokens
                 if use_delayed_sampling \
                     and model_input.async_callback is not None:
                     model_input.async_callback()
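
On the runner side, the sampler's HPU caches are populated right after the sampling tensors are built, and the mediapipe branch moves multimodal tensors to HPU only when they are not already there. A small sketch of the flag parsing and the conditional move (a CPU device and plain .to() stand in for self.device and the runner's move_to_device helper):

import os
import torch

device = torch.device("cpu")                      # stand-in for self.device (HPU in the runner)
use_mediapipe = os.getenv("VLLM_USE_MEDIA_PIPELINE",
                          "false").lower() in ("1", "true", "yes")

multi_modal_kwargs = {"pixel_values": torch.zeros(1, 3, 8, 8), "image_index": 0}
if use_mediapipe:
    for key in multi_modal_kwargs:
        value = multi_modal_kwargs[key]
        # Only tensors carry a .device; skip plain Python values and avoid
        # redundant copies for tensors already on the target device.
        if hasattr(value, "device") and value.device != device:
            multi_modal_kwargs[key] = value.to(device)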
