@@ -231,6 +231,12 @@ def _prepare_draft_batch(
             ScheduledRequests: The prepared draft batch
         """
         try:
+            for req in scheduled_requests.all_requests():
+                draft_model = self.draft_model_engine.model.draft_model if self.use_static_draft_loop else self.draft_model_engine.model
+                if hasattr(draft_model.model, "d2t"):
+                    req.d2t = draft_model.model.d2t.data
+                req.py_draft_use_greedy_sampling = self.use_static_draft_loop
+
             draft_batch = ScheduledRequests()

             for request in scheduled_requests.context_requests:
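Note on the hunk above: each scheduled request now carries the draft model's `d2t` tensor (when the checkpoint provides one) so later sampling can translate draft-vocabulary ids into target-vocabulary ids, and greedy draft sampling is forced whenever the static draft loop is active. A minimal sketch of how such a mapping is typically applied, assuming `d2t` stores per-id offsets as in EAGLE-style draft heads; the helper name is illustrative and not part of this change:

    import torch

    def remap_draft_to_target(draft_ids: torch.Tensor,
                              d2t: torch.Tensor | None) -> torch.Tensor:
        # With an offset table, target_id = draft_id + d2t[draft_id];
        # without one, the draft and target vocabularies are assumed to match.
        if d2t is None:
            return draft_ids
        return draft_ids + d2t[draft_ids]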
@@ -530,7 +536,8 @@ def _setup_draft_batch_and_resources(
         return draft_batch, req_id_to_old_request

     def process_static_draft_outputs(
-            self, outputs: torch.Tensor | SampleState,
+            self,
+            outputs: dict[str, torch.Tensor] | tuple[torch.Tensor, SampleState],
             draft_batch: ScheduledRequests,
             req_id_to_old_request: Dict[int, LlmRequest]) -> None:
         """
@@ -541,23 +548,26 @@ def process_static_draft_outputs(
             draft_batch: The draft batch that was processed
             req_id_to_old_request: Mapping from draft request ID to original request
         """
-        if isinstance(outputs, torch.Tensor):
-            # For non-overlap scheduler path.
-            outputs_host = outputs.cpu()
+
+        if isinstance(outputs, dict):
+            draft_tokens_host = outputs["new_draft_tokens"].cpu()
+            draft_logits = outputs["draft_logits"]
         else:
-            outputs_host = outputs.host.new_tokens
-            outputs.sampler_event.synchronize()
-
-        for token_idx in range(self.max_draft_tokens):
-            for req_idx, req in enumerate(draft_batch.all_requests()):
-                target_model_req = req_id_to_old_request[req.py_request_id]
-                if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
-                    # Chunked prefill request in progress; no need to append draft tokens
-                    continue
+            draft_logits = outputs[0]
+            draft_tokens_host = outputs[1].host.new_tokens
+            outputs[1].sampler_event.synchronize()

-                target_req = req_id_to_old_request[req.py_request_id]
-                target_req.py_draft_tokens.append(
-                    outputs_host[token_idx][req_idx])
+        for req_idx, req in enumerate(draft_batch.all_requests()):
+            target_model_req = req_id_to_old_request[req.py_request_id]
+            if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                # Chunked prefill request in progress; no need to append draft tokens
+                continue
+            py_draft_logits = []
+            for token_idx in range(self.max_draft_tokens):
+                target_model_req.py_draft_tokens.append(
+                    draft_tokens_host[token_idx][req_idx])
+                py_draft_logits.append(draft_logits[token_idx][req_idx])
+            target_model_req.py_draft_logits = torch.stack(py_draft_logits)

         # Clean up draft resources
         for req in draft_batch.all_requests():
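Note on the hunk above: `process_static_draft_outputs` now accepts two output shapes. Without the overlap scheduler it receives a dict holding `new_draft_tokens` and `draft_logits`; with overlap it receives a `(draft_logits, SampleState)` tuple whose tokens are read from the host after the sampler event completes. In both cases the method walks the requests once and stacks each request's per-step logits into `py_draft_logits`. A small sketch of the indexing it relies on, with shapes assumed to be `[max_draft_tokens, num_requests]` for tokens and `[max_draft_tokens, num_requests, vocab]` for logits (an assumption, not stated in the diff):

    import torch

    max_draft_tokens, num_requests, vocab = 3, 2, 8
    tokens = torch.randint(0, vocab, (max_draft_tokens, num_requests))
    logits = torch.randn(max_draft_tokens, num_requests, vocab)
    outputs = {"new_draft_tokens": tokens, "draft_logits": logits}

    # Per request, the loop appends tokens[token_idx][req_idx] and stacks
    # logits[token_idx][req_idx] into one [max_draft_tokens, vocab] tensor.
    req_idx = 0
    py_draft_logits = torch.stack(
        [outputs["draft_logits"][t][req_idx] for t in range(max_draft_tokens)])
    assert py_draft_logits.shape == (max_draft_tokens, vocab)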
@@ -710,23 +720,26 @@ def generate_draft_tokens_with_overlap(
             # Only update target inputs, cleanup will be done in executor loop
             self._update_target_inputs_with_draft_tokens(
                 target_inputs,
-                outputs,
+                outputs["new_draft_tokens"],
                 draft_position=0,
                 draft_length=self.max_draft_tokens,
                 draft_batch=draft_batch,
                 req_id_to_old_request=req_id_to_old_request)

-            new_tokens_host = outputs.to(device='cpu', non_blocking=True)
+            new_tokens_host = outputs["new_draft_tokens"].to(device='cpu',
+                                                             non_blocking=True)
             sampler_event = torch.cuda.Event()
             sampler_event.record()

-            outputs = SampleState(
+            sample_state = SampleState(
                 scheduled_requests=draft_batch,
-                device=SampleStateTensors(new_tokens=outputs),
+                device=SampleStateTensors(
+                    new_tokens=outputs["new_draft_tokens"]),
                 host=SampleStateTensors(new_tokens=new_tokens_host),
                 sampler_event=sampler_event)

-            return target_inputs, outputs, draft_batch
+            return target_inputs, (outputs["draft_logits"],
+                                   sample_state), draft_batch

         # Handle guided decoder and sampling for non-static loop
         if self.guided_decoder is not None:
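Note on the final hunk: on the static-loop path with the overlap scheduler, `generate_draft_tokens_with_overlap` now returns the draft logits alongside a freshly built `SampleState` rather than the bare state, so the executor can attach `py_draft_logits` once the asynchronous copy finishes. A hedged sketch of the consumer-side unpacking this implies; only the attribute names visible in the diff are assumed:

    def unpack_static_overlap_outputs(outputs):
        # `outputs` is the (draft_logits, SampleState) tuple returned above.
        draft_logits, sample_state = outputs
        sample_state.sampler_event.synchronize()          # wait for the async D2H copy
        draft_tokens_host = sample_state.host.new_tokens  # one row per draft step
        return draft_logits, draft_tokens_host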