@@ -576,22 +576,53 @@ def _update_draft_tokens_for_target_inputs(
         if target_inputs.next_draft_tokens is None:
             return
 
-        if draft_tensors is not None:
-            for req_idx, request in enumerate(draft_batch.all_requests()):
-                target_req = self.req_id_to_old_request[request.py_request_id]
-                if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
-                    # Skip prefill requests
-                    continue
-                # Get the index of the draft/target tokens in the device tensor
-                draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
-                target_idx = target_req.py_seq_slot
-                target_inputs.new_tokens[draft_position + 1:draft_position +
-                                         draft_length + 1, target_idx,
-                                         0] = draft_tensors[0:draft_length,
-                                                            draft_idx]
-                target_inputs.next_draft_tokens[
-                    target_idx, draft_position:draft_position +
-                    draft_length] = draft_tensors[0:draft_length, draft_idx]
+        draft_indices = []
+        target_indices = []
+        for req_idx, request in enumerate(draft_batch.all_requests()):
+            target_req = self.req_id_to_old_request[request.py_request_id]
+            if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                # Skip prefill requests
+                continue
+            # Get the index of the draft/target tokens in the device tensor
+            draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
+            target_idx = target_req.py_seq_slot
+            draft_indices.append(draft_idx)
+            target_indices.append(target_idx)
+
+        if len(draft_indices) == 0:
+            return
+
+        device = draft_tensors.device
+
+        # Create index tensors on pinned host memory for async H2D copies
+        draft_indices_tensor = torch.tensor(draft_indices,
+                                            dtype=torch.long,
+                                            pin_memory=True).to(
+                                                device, non_blocking=True)
+        target_indices_tensor = torch.tensor(target_indices,
+                                             dtype=torch.long,
+                                             pin_memory=True).to(
+                                                 device, non_blocking=True)
+
+        # Pre-slice the draft tensors: [draft_length, batch_size]
+        draft_slice = draft_tensors[0:draft_length]
+
+        # Gather all source data at once with a single index_select kernel
+        # Result shape: [draft_length, num_requests]
+        gathered = draft_slice.index_select(1, draft_indices_tensor).to(
+            torch.int32)
+
+        # Scatter to new_tokens using advanced indexing (single kernel)
+        # Shape: [draft_length, num_requests] -> [seq_len, batch_size, beam_width]
+        target_inputs.new_tokens[draft_position + 1:draft_position +
+                                 draft_length + 1, target_indices_tensor,
+                                 0] = gathered
+
+        # Scatter to next_draft_tokens using advanced indexing (single kernel)
+        # Shape: [num_requests, draft_length] -> [batch_size, max_draft_len]
+        target_inputs.next_draft_tokens[target_indices_tensor,
+                                        draft_position:draft_position +
+                                        draft_length] = gathered.t()
 
     def _setup_draft_batch_and_resources(
         self,
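
For context, a minimal standalone sketch of the batched gather/scatter pattern introduced by this hunk. The function name, argument names, and shapes below are illustrative assumptions for demonstration only, not the actual TensorRT-LLM runtime objects:

import torch

def scatter_draft_tokens(draft_tensors,      # [max_draft_len, num_slots], on GPU
                         new_tokens,         # [seq_len, batch_size, beam_width]
                         next_draft_tokens,  # [batch_size, max_draft_len]
                         draft_indices, target_indices,
                         draft_position, draft_length):
    device = draft_tensors.device

    # Index tensors are built on pinned host memory so the host-to-device
    # copy can be issued asynchronously (assumes a CUDA device is present).
    draft_idx = torch.tensor(draft_indices, dtype=torch.long,
                             pin_memory=True).to(device, non_blocking=True)
    target_idx = torch.tensor(target_indices, dtype=torch.long,
                              pin_memory=True).to(device, non_blocking=True)

    # One gather kernel: select the columns belonging to in-progress requests.
    # Result shape: [draft_length, num_requests]
    gathered = draft_tensors[:draft_length].index_select(1, draft_idx)

    # One scatter kernel per destination via advanced indexing, instead of a
    # Python loop that launches two small copies per request.
    new_tokens[draft_position + 1:draft_position + draft_length + 1,
               target_idx, 0] = gathered
    next_draft_tokens[target_idx,
                      draft_position:draft_position + draft_length] = gathered.t()

The per-request Python loop in the old code launched two tiny device copies per request; the sketch above mirrors the new approach of collapsing that into one gather and two scatters regardless of batch size.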