4949)
5050from vllm .v1 .worker .gpu .mm .mrope_utils import MRopeState
5151from vllm .v1 .worker .gpu .sample .logprob import compute_prompt_logprobs
52- from vllm .v1 .worker .gpu .sample .metadata import SamplingMetadata
5352from vllm .v1 .worker .gpu .sample .output import SamplerOutput
5453from vllm .v1 .worker .gpu .sample .sampler import Sampler
5554from vllm .v1 .worker .gpu .spec_decode import init_speculator
@@ -139,7 +138,12 @@ def __init__(
139138 dtype = self .dtype ,
140139 device = self .device ,
141140 )
142- self .sampler = Sampler (logprobs_mode = self .model_config .logprobs_mode )
141+ self .sampler = Sampler (
142+ max_num_reqs = self .max_num_reqs ,
143+ vocab_size = self .vocab_size ,
144+ device = self .device ,
145+ logprobs_mode = self .model_config .logprobs_mode ,
146+ )
143147
144148 # CUDA graphs.
145149 self .cudagraph_manager = CudaGraphManager (
@@ -310,12 +314,14 @@ def _dummy_sampler_run(
310314 hidden_states : torch .Tensor ,
311315 ) -> None :
312316 num_reqs = hidden_states .shape [0 ]
313- sampling_metadata = SamplingMetadata .make_dummy (
314- num_reqs = num_reqs ,
315- device = self .device ,
316- )
317317 logits = self .model .compute_logits (hidden_states )
318- self .sampler (logits , sampling_metadata )
318+ idx_mapping = torch .arange (num_reqs , dtype = torch .int32 , device = self .device )
319+ idx_mapping_np = np .arange (num_reqs , dtype = np .int32 )
320+ pos = torch .zeros (num_reqs , dtype = torch .int64 , device = self .device )
321+ # NOTE(woosuk): During the initial memory profiling, the sampler may skip
322+ # top_k, top_p, and logprobs, using less GPU memory than what is possible
323+ # during actual execution.
324+ self .sampler (logits , idx_mapping , idx_mapping_np , pos )
319325
320326 @torch .inference_mode ()
321327 def profile_run (self ) -> None :
@@ -401,9 +407,10 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
401407 assert new_req_data .prefill_token_ids is not None
402408 assert new_req_data .sampling_params is not None
403409 req_id = new_req_data .req_id
410+ prompt_len = len (new_req_data .prompt_token_ids )
404411 self .req_states .add_request (
405412 req_id = req_id ,
406- prompt_len = len ( new_req_data . prompt_token_ids ) ,
413+ prompt_len = prompt_len ,
407414 prefill_token_ids = new_req_data .prefill_token_ids ,
408415 num_computed_tokens = new_req_data .num_computed_tokens ,
409416 sampling_params = new_req_data .sampling_params ,
@@ -423,6 +430,9 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
423430 self .block_tables .append_block_ids (
424431 req_index , new_req_data .block_ids , overwrite = True
425432 )
433+ self .sampler .add_request (
434+ req_index , prompt_len , new_req_data .sampling_params
435+ )
426436
427437 # Add new blocks for the existing requests.
428438 cached_reqs = scheduler_output .scheduled_cached_reqs
@@ -436,6 +446,11 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
436446
437447 self .req_states .apply_staged_writes ()
438448 self .block_tables .apply_staged_writes ()
449+ self .sampler .apply_staged_writes (
450+ self .req_states .prefill_token_ids .gpu ,
451+ self .req_states .prefill_len .np ,
452+ self .req_states .prompt_len ,
453+ )
439454 if self .uses_mrope :
440455 self .mrope_states .apply_staged_writes ()
441456
@@ -612,10 +627,10 @@ def sample(
612627 self ,
613628 hidden_states : torch .Tensor ,
614629 input_batch : InputBatch ,
615- sampling_metadata : SamplingMetadata ,
616630 grammar_output : GrammarOutput | None ,
617631 ) -> tuple [SamplerOutput , torch .Tensor , torch .Tensor ]:
618632 sample_hidden_states = hidden_states [input_batch .logits_indices ]
633+ sample_pos = input_batch .positions [input_batch .logits_indices ]
619634 logits = self .model .compute_logits (sample_hidden_states )
620635 if grammar_output is not None :
621636 # Apply grammar bitmask to the logits in-place.
@@ -627,7 +642,12 @@ def sample(
627642 )
628643
629644 # Sample tokens and compute logprobs (if needed).
630- sampler_output = self .sampler (logits , sampling_metadata )
645+ sampler_output = self .sampler (
646+ logits ,
647+ input_batch .expanded_idx_mapping ,
648+ input_batch .idx_mapping_np ,
649+ sample_pos ,
650+ )
631651
632652 if input_batch .num_draft_tokens == 0 :
633653 # No draft tokens (common case).
@@ -766,7 +786,7 @@ def postprocess(
766786 input_batch .idx_mapping ,
767787 self .req_states .num_computed_tokens .gpu ,
768788 self .req_states .last_sampled_tokens ,
769- self .req_states .output_bin_counts ,
789+ self .sampler . penalties_state .output_bin_counts ,
770790 sampled_tokens ,
771791 num_sampled ,
772792 num_rejected ,
@@ -786,7 +806,6 @@ def postprocess(
786806 def propose_draft (
787807 self ,
788808 input_batch : InputBatch ,
789- sampling_metadata : SamplingMetadata ,
790809 last_hidden_states : torch .Tensor ,
791810 aux_hidden_states : list [torch .Tensor ] | None ,
792811 num_sampled : torch .Tensor ,
@@ -801,13 +820,14 @@ def propose_draft(
801820 ]
802821 draft_tokens = self .speculator .propose (
803822 input_batch ,
804- sampling_metadata ,
805823 last_hidden_states ,
806824 aux_hidden_states ,
807825 num_sampled ,
808826 num_rejected ,
809827 last_sampled_tokens ,
810828 next_prefill_tokens ,
829+ self .sampler .sampling_states .temperature .gpu ,
830+ self .sampler .sampling_states .seeds .gpu ,
811831 )
812832 return draft_tokens
813833
@@ -893,12 +913,6 @@ def execute_model(
893913 scheduler_output ,
894914 num_tokens_after_padding ,
895915 )
896-
897- pos = input_batch .positions [input_batch .logits_indices ]
898- sampling_metadata = self .req_states .make_sampling_metadata (
899- input_batch .expanded_idx_mapping , input_batch .idx_mapping_np , pos
900- )
901-
902916 if self .lora_config :
903917 # Activate LoRA adapters.
904918 lora_inputs = self .req_states .make_lora_inputs (
@@ -917,7 +931,6 @@ def execute_model(
917931 device = self .device ,
918932 )
919933 self .prepare_dummy_attn_metadata (input_batch )
920- sampling_metadata = None
921934
922935 # Run model.
923936 if cudagraph_mode == CUDAGraphMode .FULL :
@@ -946,7 +959,7 @@ def execute_model(
946959 positions = positions ,
947960 )
948961
949- self .execute_model_state = hidden_states , input_batch , sampling_metadata
962+ self .execute_model_state = hidden_states , input_batch
950963 return None
951964
952965 @torch .inference_mode ()
@@ -955,12 +968,11 @@ def sample_tokens(
955968 grammar_output : GrammarOutput | None ,
956969 ) -> AsyncOutput | ModelRunnerOutput :
957970 assert self .execute_model_state is not None
958- hidden_states , input_batch , sampling_metadata = self .execute_model_state
971+ hidden_states , input_batch = self .execute_model_state
959972 self .execute_model_state = None # type: ignore
960- assert sampling_metadata is not None
961973
962974 sampler_output , num_sampled , num_rejected = self .sample (
963- hidden_states , input_batch , sampling_metadata , grammar_output
975+ hidden_states , input_batch , grammar_output
964976 )
965977 prompt_logprobs_dict = self .compute_prompt_logprobs (hidden_states , input_batch )
966978
@@ -992,7 +1004,6 @@ def sample_tokens(
9921004 if self .do_spec_decode :
9931005 draft_tokens = self .propose_draft (
9941006 input_batch ,
995- sampling_metadata ,
9961007 hidden_states ,
9971008 None , # aux_hidden_states
9981009 num_sampled ,
0 commit comments