Commit 90c0836

[Model Runner V2] Refactor Sampler (vllm-project#32245)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 8ef50d9 commit 90c0836

File tree

7 files changed, +289 -269 lines changed

vllm/v1/worker/gpu/model_runner.py

Lines changed: 36 additions & 25 deletions
@@ -49,7 +49,6 @@
 )
 from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
 from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
-from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.sampler import Sampler
 from vllm.v1.worker.gpu.spec_decode import init_speculator
@@ -139,7 +138,12 @@ def __init__(
             dtype=self.dtype,
             device=self.device,
         )
-        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
+        self.sampler = Sampler(
+            max_num_reqs=self.max_num_reqs,
+            vocab_size=self.vocab_size,
+            device=self.device,
+            logprobs_mode=self.model_config.logprobs_mode,
+        )
 
         # CUDA graphs.
         self.cudagraph_manager = CudaGraphManager(
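
Note: the constructor change is the heart of the refactor. The Sampler is now sized up front with max_num_reqs, vocab_size, and device, so it can own persistent per-request buffers instead of being handed a freshly built SamplingMetadata every step. A minimal sketch of that pattern, with invented class and field names rather than vLLM's actual Sampler:

import torch

class PersistentStateSampler:
    """Sketch only: a sampler that pre-allocates per-request state by slot."""

    def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device):
        # Fixed-size buffers indexed by request slot; no per-step metadata object.
        self.temperature = torch.ones(max_num_reqs, device=device)
        self.top_p = torch.ones(max_num_reqs, device=device)
        # Histogram of generated tokens per request, e.g. for penalties.
        self.output_bin_counts = torch.zeros(
            max_num_reqs, vocab_size, dtype=torch.int32, device=device
        )
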
@@ -310,12 +314,14 @@ def _dummy_sampler_run(
         hidden_states: torch.Tensor,
     ) -> None:
         num_reqs = hidden_states.shape[0]
-        sampling_metadata = SamplingMetadata.make_dummy(
-            num_reqs=num_reqs,
-            device=self.device,
-        )
         logits = self.model.compute_logits(hidden_states)
-        self.sampler(logits, sampling_metadata)
+        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
+        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
+        pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
+        # NOTE(woosuk): During the initial memory profiling, the sampler may skip
+        # top_k, top_p, and logprobs, using less GPU memory than what is possible
+        # during actual execution.
+        self.sampler(logits, idx_mapping, idx_mapping_np, pos)
 
     @torch.inference_mode()
     def profile_run(self) -> None:
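
Note: with SamplingMetadata gone, the call site passes an index mapping (GPU and NumPy copies) plus the token positions, and the sampler looks parameters up in its own buffers; for the dummy profiling run an identity arange mapping is enough. A hedged sketch of that lookup pattern (the function and argument names below are illustrative, not the real Sampler.__call__):

import torch

def apply_per_request_temperature(
    logits: torch.Tensor,       # [num_tokens, vocab_size]
    temperature: torch.Tensor,  # [max_num_reqs], persistent buffer on the GPU
    idx_mapping: torch.Tensor,  # [num_tokens], request slot of each logits row
) -> torch.Tensor:
    # Gather each row's temperature from the persistent buffer, then rescale.
    temp = temperature[idx_mapping.long()].unsqueeze(1).clamp_min(1e-5)
    return logits / temp
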
@@ -401,9 +407,10 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
             assert new_req_data.prefill_token_ids is not None
             assert new_req_data.sampling_params is not None
             req_id = new_req_data.req_id
+            prompt_len = len(new_req_data.prompt_token_ids)
             self.req_states.add_request(
                 req_id=req_id,
-                prompt_len=len(new_req_data.prompt_token_ids),
+                prompt_len=prompt_len,
                 prefill_token_ids=new_req_data.prefill_token_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 sampling_params=new_req_data.sampling_params,
@@ -423,6 +430,9 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
             self.block_tables.append_block_ids(
                 req_index, new_req_data.block_ids, overwrite=True
             )
+            self.sampler.add_request(
+                req_index, prompt_len, new_req_data.sampling_params
+            )
 
         # Add new blocks for the existing requests.
         cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -436,6 +446,11 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
 
         self.req_states.apply_staged_writes()
         self.block_tables.apply_staged_writes()
+        self.sampler.apply_staged_writes(
+            self.req_states.prefill_token_ids.gpu,
+            self.req_states.prefill_len.np,
+            self.req_states.prompt_len,
+        )
         if self.uses_mrope:
             self.mrope_states.apply_staged_writes()
 
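
Note: update_states now mirrors every request into the sampler. add_request stages the per-request sampling parameters by slot, and apply_staged_writes flushes them (together with the prefill-token and prompt-length views) once per step. A rough sketch of the stage-then-flush idea; the class below is made up for illustration, not vLLM's CPU/GPU buffer types:

import numpy as np
import torch

class StagedBuffer:
    """Sketch of a CPU-staged, GPU-resident parameter buffer."""

    def __init__(self, max_num_reqs: int, device: torch.device):
        self.np = np.zeros(max_num_reqs, dtype=np.float32)   # CPU staging copy
        self.gpu = torch.zeros(max_num_reqs, device=device)  # device copy
        self.dirty: list[int] = []

    def stage(self, req_index: int, value: float) -> None:
        # Called while adding a request: record the write on the CPU side only.
        self.np[req_index] = value
        self.dirty.append(req_index)

    def apply_staged_writes(self) -> None:
        # Flush all staged slots to the GPU in a single indexed copy.
        if not self.dirty:
            return
        idx = torch.tensor(self.dirty, dtype=torch.long, device=self.gpu.device)
        vals = torch.from_numpy(self.np[self.dirty]).to(self.gpu.device)
        self.gpu[idx] = vals
        self.dirty.clear()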

@@ -612,10 +627,10 @@ def sample(
         self,
         hidden_states: torch.Tensor,
         input_batch: InputBatch,
-        sampling_metadata: SamplingMetadata,
         grammar_output: GrammarOutput | None,
     ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
         sample_hidden_states = hidden_states[input_batch.logits_indices]
+        sample_pos = input_batch.positions[input_batch.logits_indices]
         logits = self.model.compute_logits(sample_hidden_states)
         if grammar_output is not None:
             # Apply grammar bitmask to the logits in-place.
@@ -627,7 +642,12 @@
             )
 
         # Sample tokens and compute logprobs (if needed).
-        sampler_output = self.sampler(logits, sampling_metadata)
+        sampler_output = self.sampler(
+            logits,
+            input_batch.expanded_idx_mapping,
+            input_batch.idx_mapping_np,
+            sample_pos,
+        )
 
         if input_batch.num_draft_tokens == 0:
             # No draft tokens (common case).
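
Note: sample() now derives the sampled positions itself and calls the sampler with input_batch.expanded_idx_mapping. The "expanded" mapping matters under speculative decoding, where a single request can own several logits rows (draft plus bonus positions) and each row must still find its request's sampling parameters. A small, self-contained example of how such a mapping can be built (the counts are invented for illustration):

import torch

# Three requests sampling 1, 3, and 2 positions respectively (e.g. because of
# draft tokens). The expanded mapping repeats each request index once per
# logits row, so per-request parameters can be gathered row-wise.
idx_mapping = torch.tensor([0, 1, 2], dtype=torch.int32)
num_sample_positions = torch.tensor([1, 3, 2])
expanded_idx_mapping = torch.repeat_interleave(idx_mapping, num_sample_positions)
print(expanded_idx_mapping)  # tensor([0, 1, 1, 1, 2, 2], dtype=torch.int32)
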
@@ -766,7 +786,7 @@ def postprocess(
             input_batch.idx_mapping,
             self.req_states.num_computed_tokens.gpu,
             self.req_states.last_sampled_tokens,
-            self.req_states.output_bin_counts,
+            self.sampler.penalties_state.output_bin_counts,
             sampled_tokens,
             num_sampled,
             num_rejected,
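
Note: postprocess() now reads the output-token histogram from self.sampler.penalties_state instead of from the request states, so the bookkeeping behind repetition/frequency-style penalties is owned by the sampler. A hedged sketch of what a histogram-based penalty can look like; the function, its arguments, and the scalar penalty are assumptions, not the vLLM implementation:

import torch

def update_and_penalize(
    logits: torch.Tensor,             # [num_reqs, vocab_size]
    output_bin_counts: torch.Tensor,  # [max_num_reqs, vocab_size] persistent histogram
    idx_mapping: torch.Tensor,        # [num_reqs] request slot per logits row
    sampled_tokens: torch.Tensor,     # [num_reqs] tokens sampled last step
    frequency_penalty: float = 0.5,   # assumed scalar for simplicity
) -> torch.Tensor:
    rows = idx_mapping.long()
    # Record each request's newly sampled token in its histogram row.
    output_bin_counts[rows, sampled_tokens.long()] += 1
    # Penalize logits in proportion to how often each token was generated.
    return logits - frequency_penalty * output_bin_counts[rows].float()
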
@@ -786,7 +806,6 @@ def postprocess(
     def propose_draft(
         self,
         input_batch: InputBatch,
-        sampling_metadata: SamplingMetadata,
         last_hidden_states: torch.Tensor,
         aux_hidden_states: list[torch.Tensor] | None,
         num_sampled: torch.Tensor,
@@ -801,13 +820,14 @@
         ]
         draft_tokens = self.speculator.propose(
             input_batch,
-            sampling_metadata,
             last_hidden_states,
             aux_hidden_states,
             num_sampled,
             num_rejected,
             last_sampled_tokens,
             next_prefill_tokens,
+            self.sampler.sampling_states.temperature.gpu,
+            self.sampler.sampling_states.seeds.gpu,
         )
         return draft_tokens
 
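
Note: propose_draft no longer takes a sampling_metadata argument; the speculator instead receives the sampler's own persistent temperature and seed buffers, so draft proposal can sample consistently with the target model. A toy, CPU-only sketch of seeded, temperature-scaled sampling for a single request (not the speculator's actual code):

import torch

def seeded_sample(logits: torch.Tensor, temperature: float, seed: int) -> int:
    # Reproducible sampling for one request: the same seed and temperature
    # yield the same token, which keeps draft and target sampling aligned.
    gen = torch.Generator().manual_seed(seed)
    probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
    return int(torch.multinomial(probs, num_samples=1, generator=gen).item())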

@@ -893,12 +913,6 @@ def execute_model(
             scheduler_output,
             num_tokens_after_padding,
         )
-
-        pos = input_batch.positions[input_batch.logits_indices]
-        sampling_metadata = self.req_states.make_sampling_metadata(
-            input_batch.expanded_idx_mapping, input_batch.idx_mapping_np, pos
-        )
-
         if self.lora_config:
             # Activate LoRA adapters.
             lora_inputs = self.req_states.make_lora_inputs(
@@ -917,7 +931,6 @@
                 device=self.device,
             )
             self.prepare_dummy_attn_metadata(input_batch)
-            sampling_metadata = None
 
         # Run model.
         if cudagraph_mode == CUDAGraphMode.FULL:
@@ -946,7 +959,7 @@
             positions=positions,
         )
 
-        self.execute_model_state = hidden_states, input_batch, sampling_metadata
+        self.execute_model_state = hidden_states, input_batch
         return None
 
     @torch.inference_mode()
@@ -955,12 +968,11 @@ def sample_tokens(
         grammar_output: GrammarOutput | None,
     ) -> AsyncOutput | ModelRunnerOutput:
         assert self.execute_model_state is not None
-        hidden_states, input_batch, sampling_metadata = self.execute_model_state
+        hidden_states, input_batch = self.execute_model_state
        self.execute_model_state = None  # type: ignore
-        assert sampling_metadata is not None
 
         sampler_output, num_sampled, num_rejected = self.sample(
-            hidden_states, input_batch, sampling_metadata, grammar_output
+            hidden_states, input_batch, grammar_output
         )
         prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
 
@@ -992,7 +1004,6 @@ def sample_tokens(
         if self.do_spec_decode:
            draft_tokens = self.propose_draft(
                 input_batch,
-                sampling_metadata,
                 hidden_states,
                 None,  # aux_hidden_states
                 num_sampled,

vllm/v1/worker/gpu/sample/metadata.py

Lines changed: 0 additions & 79 deletions
This file was deleted.
