Commit b38c808

Merge PR1974 intervl:cache prompt_tokens for sampling metadata
In sampling, when penalties are applied, the prompt_tokens tensor is regenerated for every decode step, which takes time. Instead, we can cache it and reset the cache when the set of requests changes.
1 parent 65abdfb commit b38c808
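
The change follows a simple cache-and-invalidate pattern: keep the padded prompt-token tensor built for the previous decode step, reuse it while the set of sequence ids in the batch is unchanged, and drop it as soon as the batch composition changes. A minimal standalone sketch of that pattern (the class and method names below are illustrative, not the vLLM API):

```python
from typing import Callable, Optional

import torch


class PromptTokensCache:
    """Illustrative cache keyed on the set of sequence ids in the batch."""

    def __init__(self) -> None:
        self._cached_tensor: Optional[torch.Tensor] = None
        self._cached_seq_ids: Optional[set] = None

    def get(self, seq_ids: set,
            build: Callable[[], torch.Tensor]) -> torch.Tensor:
        # Invalidate when the batch composition changed.
        if seq_ids != self._cached_seq_ids:
            self._cached_tensor = None
            self._cached_seq_ids = set(seq_ids)
        # Rebuild (pad + host-to-device copy) only on a cache miss.
        if self._cached_tensor is None:
            self._cached_tensor = build()
        return self._cached_tensor
```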

File tree

3 files changed: +47 -16 lines

vllm/model_executor/layers/sampler.py

Lines changed: 13 additions & 2 deletions
@@ -197,6 +197,9 @@ def __init__(self):
         # speculative decoding and when prompt embeddings are specified.
         self.include_gpu_probs_tensor = False
         self.should_modify_greedy_probs_inplace = False
+        # Add HPU cache class variables
+        self._prompt_tokens_hpu_cache: Optional[torch.Tensor] = None
+        self._cached_seq_ids: Optional[set] = None

     def _init_sampling_tensors(
         self,
@@ -216,8 +219,10 @@ def _init_sampling_tensors(

         # Initialize new sampling tensors
         (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
-         top_k_scalar, top_p_scalar) = SamplingTensors.from_sampling_metadata(
-             sampling_metadata, vocab_size, logits.device, logits.dtype)
+         top_k_scalar, top_p_scalar, current_seq_ids) = \
+            SamplingTensors.from_sampling_metadata(
+                sampling_metadata, vocab_size, logits.device, logits.dtype, \
+                self._prompt_tokens_hpu_cache, self._cached_seq_ids)

         self._sampling_tensors = sampling_tensors
         self._do_penalties = do_penalties
@@ -227,6 +232,12 @@ def _init_sampling_tensors(
         self._top_p_scalar = top_p_scalar

         self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)
+        # Check if batch composition changed - if so, invalidate prompt cache
+
+        # After tensors are created, update cache
+        if self._cached_seq_ids != current_seq_ids:
+            self._prompt_tokens_hpu_cache = None
+        self._cached_seq_ids = current_seq_ids

     def forward(
         self,

vllm/model_executor/sampling_metadata.py

Lines changed: 28 additions & 14 deletions
@@ -16,6 +16,8 @@
     make_tensor_with_pad_align)

 _SAMPLING_EPS = 1e-5
+pin_memory = is_pin_memory_available()
+is_hpu = current_platform.is_hpu()


 @dataclass
@@ -286,7 +288,7 @@ def _prepare_seq_groups(

         if seq_group_metadata.is_prompt:
             if sampling_params.seed is not None:
-                if current_platform.is_hpu():
+                if is_hpu:
                     import habana_frameworks.torch.hpu.random as htrandom
                     generator = \
                         htrandom.default_generators[
@@ -420,8 +422,10 @@ def from_sampling_metadata(
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
+        prompt_tokens_cache: torch.tensor,
+        past_seq_ids: set,
     ) -> tuple["SamplingTensors", bool, bool, bool, Optional[int],
-               Optional[float]]:
+               Optional[float], Optional[torch.tensor]]:
         prompt_tokens: list[array] = []
         output_tokens: list[array] = []
         top_ks: list[int] = []
@@ -434,6 +438,7 @@ def from_sampling_metadata(
         do_penalties = False
         do_top_p_top_k = False
         do_min_p = False
+        current_seq_ids = set()

         assert sampling_metadata.seq_groups is not None
         for seq_group in sampling_metadata.seq_groups:
@@ -508,6 +513,9 @@ def from_sampling_metadata(
                     seq_data = seq_group.seq_data[seq_id]
                     prompt_tokens.append(seq_data.prompt_token_ids_array)
                     output_tokens.append(seq_data.output_token_ids_array)
+            current_seq_ids.update(seq_ids)
+        if current_seq_ids != past_seq_ids:
+            prompt_tokens_cache = None

         top_k_scalar = top_ks[0] if do_top_p_top_k and all(
             k == top_ks[0] for k in top_ks) else None
@@ -527,9 +535,10 @@ def from_sampling_metadata(
             vocab_size,
             device,
             dtype,
+            prompt_tokens_cache,
         )
         return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
-                top_k_scalar, top_p_scalar)
+                top_k_scalar, top_p_scalar, current_seq_ids)

     @classmethod
     def from_lists(
@@ -546,23 +555,28 @@ def from_lists(
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
+        prompt_tokens_cache: torch.tensor,
     ) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
-        pin_memory = is_pin_memory_available()

         do_penalties = prompt_tokens or output_tokens

         if do_penalties:
-            if current_platform.is_hpu():
-                prompt_t = make_tensor_with_pad_align(
-                    prompt_tokens,
-                    vocab_size,
-                    device="cpu",
-                    dtype=torch.int64,
-                    pin_memory=pin_memory,
-                    max_len_align=1024,
-                )
+            if is_hpu:
+                if (prompt_tokens_cache is not None and
+                        prompt_tokens_cache.device == device):
+                    # Reuse cached prompt_tokens already on HPU
+                    prompt_t = prompt_tokens_cache
+                else:
+                    prompt_t = make_tensor_with_pad_align(
+                        prompt_tokens,
+                        vocab_size,
+                        device="cpu",
+                        dtype=torch.int64,
+                        pin_memory=pin_memory,
+                        max_len_align=1024,
+                    )
             output_t = make_tensor_with_pad_align(
                 output_tokens,
                 vocab_size,
@@ -647,6 +661,6 @@ def from_lists(
                                                            non_blocking=True),
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
-            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
+            prompt_tokens=prompt_t.to(device=device, non_blocking=True) if prompt_t.device != device else prompt_t,
             output_tokens=output_t.to(device=device, non_blocking=True),
         )
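
For context on what a cache hit avoids: with penalties enabled, every decode step otherwise pads each request's prompt token ids into an int64 tensor on the CPU (padded with vocab_size in the code above) and copies it to the device. A rough, simplified sketch of that pad-and-copy step follows; it is not the real make_tensor_with_pad_align helper, and the length alignment and pinned host memory are left out:

```python
import torch


def pad_prompts_to_tensor(prompt_tokens: list[list[int]], pad_id: int,
                          device: torch.device) -> torch.Tensor:
    """Simplified stand-in for the pad-and-copy work skipped on cache hits."""
    max_len = max(len(p) for p in prompt_tokens)
    out = torch.full((len(prompt_tokens), max_len), pad_id, dtype=torch.int64)
    for row, tokens in enumerate(prompt_tokens):
        out[row, :len(tokens)] = torch.tensor(tokens, dtype=torch.int64)
    # Without the cache, this host-to-device copy runs on every decode step.
    return out.to(device=device, non_blocking=True)
```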

vllm/worker/hpu_model_runner.py

Lines changed: 6 additions & 0 deletions
@@ -4023,6 +4023,12 @@ def try_revert_dummy_output_tokens():
                 self.cached_step_inputs.append(model_input)
                 if self.do_mark_step:
                     htorch.core.mark_step()
+                if hasattr(self.model.sampler, '_sampling_tensors') and \
+                        self.model.sampler._sampling_tensors is not None:
+                    sampling_tensors = self.model.sampler._sampling_tensors
+                    if sampling_tensors.prompt_tokens.numel() > 0:
+                        # Cache the prompt_tokens tensor that's already on HPU
+                        self.model.sampler._prompt_tokens_hpu_cache = sampling_tensors.prompt_tokens
                 if use_delayed_sampling \
                         and model_input.async_callback is not None:
                     model_input.async_callback()
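
Taken together, the sampler owns the two cache fields, SamplingTensors consults them while building tensors, and the HPU model runner writes the freshly built HPU-resident prompt_tokens back into the sampler after mark_step. A self-contained toy (illustrative only, not vLLM code) of the resulting hit/miss behaviour across decode steps:

```python
import torch


class SamplerCacheDemo:
    """Toy stand-in for the sampler-side cache fields; not vLLM code."""

    def __init__(self) -> None:
        self.prompt_tokens_cache = None
        self.cached_seq_ids = None
        self.builds = 0  # counts how often the padded tensor is rebuilt

    def step(self, seq_ids: set) -> torch.Tensor:
        if seq_ids != self.cached_seq_ids:
            self.prompt_tokens_cache = None      # batch changed: invalidate
            self.cached_seq_ids = seq_ids
        if self.prompt_tokens_cache is None:
            self.builds += 1                     # pad + copy would happen here
            self.prompt_tokens_cache = torch.zeros(len(seq_ids), 8,
                                                   dtype=torch.int64)
        return self.prompt_tokens_cache


demo = SamplerCacheDemo()
demo.step({0, 1})   # miss: tensor built and moved to the device
demo.step({0, 1})   # hit: same request set, tensor reused
demo.step({0, 2})   # miss: request set changed, tensor rebuilt
assert demo.builds == 2
```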
