@@ -556,22 +556,53 @@ def _update_draft_tokens_for_target_inputs(
556556 if target_inputs .next_draft_tokens is None :
557557 return
558558
559- if draft_tensors is not None :
560- for req_idx , request in enumerate (draft_batch .all_requests ()):
561- target_req = self .req_id_to_old_request [request .py_request_id ]
562- if target_req .state != LlmRequestState .GENERATION_IN_PROGRESS :
563- # Skip prefill requests
564- continue
565- # Get the index of the draft/target tokens in the device tensor
566- draft_idx = req_idx if self .use_static_draft_loop else request .py_seq_slot
567- target_idx = target_req .py_seq_slot
568- target_inputs .new_tokens [draft_position + 1 :draft_position +
569- draft_length + 1 , target_idx ,
570- 0 ] = draft_tensors [0 :draft_length ,
571- draft_idx ]
572- target_inputs .next_draft_tokens [
573- target_idx , draft_position :draft_position +
574- draft_length ] = draft_tensors [0 :draft_length , draft_idx ]
559+ draft_indices = []
560+ target_indices = []
561+ for req_idx , request in enumerate (draft_batch .all_requests ()):
562+ target_req = self .req_id_to_old_request [request .py_request_id ]
563+ if target_req .state != LlmRequestState .GENERATION_IN_PROGRESS :
564+ # Skip prefill requests
565+ continue
566+ # Get the index of the draft/target tokens in the device tensor
567+ draft_idx = req_idx if self .use_static_draft_loop else request .py_seq_slot
568+ target_idx = target_req .py_seq_slot
569+ draft_indices .append (draft_idx )
570+ target_indices .append (target_idx )
571+
572+ if len (draft_indices ) == 0 :
573+ return
574+
575+ device = draft_tensors .device
576+
577+ # Create index tensors
578+ draft_indices_tensor = torch .tensor (draft_indices ,
579+ dtype = torch .long ,
580+ pin_memory = True ).to (
581+ device , non_blocking = True )
582+ target_indices_tensor = torch .tensor (target_indices ,
583+ dtype = torch .long ,
584+ pin_memory = True ).to (
585+ device , non_blocking = True )
586+
587+ # Pre-slice draft tensors: [draft_length, batch_size]
588+ draft_slice = draft_tensors [0 :draft_length ]
589+
590+ # Gather all source data at once using single index_select kernel
591+ # Result shape: [draft_length, num_requests]
592+ gathered = draft_slice .index_select (1 , draft_indices_tensor ).to (
593+ torch .int32 )
594+
595+ # Scatter to new_tokens using advanced indexing (single kernel)
596+ # Shape: [draft_length, num_requests] -> [seq_len, batch_size, beam_width]
597+ target_inputs .new_tokens [draft_position + 1 :draft_position +
598+ draft_length + 1 , target_indices_tensor ,
599+ 0 ] = gathered
600+
601+ # Scatter to next_draft_tokens using advanced indexing (single kernel)
602+ # Shape: [num_requests, draft_length] -> [batch_size, max_draft_len]
603+ target_inputs .next_draft_tokens [target_indices_tensor ,
604+ draft_position :draft_position +
605+ draft_length ] = gathered .t ()
575606
576607 def _setup_draft_batch_and_resources (
577608 self ,
0 commit comments