Skip to content

Commit 8b31d1c

Browse files
committed
An optimization to the kernel
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent c9457ef commit 8b31d1c

File tree

1 file changed

+47
-16
lines changed

1 file changed

+47
-16
lines changed

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -576,22 +576,53 @@ def _update_draft_tokens_for_target_inputs(
576576
if target_inputs.next_draft_tokens is None:
577577
return
578578

579-
if draft_tensors is not None:
580-
for req_idx, request in enumerate(draft_batch.all_requests()):
581-
target_req = self.req_id_to_old_request[request.py_request_id]
582-
if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
583-
# Skip prefill requests
584-
continue
585-
# Get the index of the draft/target tokens in the device tensor
586-
draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
587-
target_idx = target_req.py_seq_slot
588-
target_inputs.new_tokens[draft_position + 1:draft_position +
589-
draft_length + 1, target_idx,
590-
0] = draft_tensors[0:draft_length,
591-
draft_idx]
592-
target_inputs.next_draft_tokens[
593-
target_idx, draft_position:draft_position +
594-
draft_length] = draft_tensors[0:draft_length, draft_idx]
579+
draft_indices = []
580+
target_indices = []
581+
for req_idx, request in enumerate(draft_batch.all_requests()):
582+
target_req = self.req_id_to_old_request[request.py_request_id]
583+
if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
584+
# Skip prefill requests
585+
continue
586+
# Get the index of the draft/target tokens in the device tensor
587+
draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
588+
target_idx = target_req.py_seq_slot
589+
draft_indices.append(draft_idx)
590+
target_indices.append(target_idx)
591+
592+
if len(draft_indices) == 0:
593+
return
594+
595+
device = draft_tensors.device
596+
597+
# Create index tensors
598+
draft_indices_tensor = torch.tensor(draft_indices,
599+
dtype=torch.long,
600+
pin_memory=True).to(
601+
device, non_blocking=True)
602+
target_indices_tensor = torch.tensor(target_indices,
603+
dtype=torch.long,
604+
pin_memory=True).to(
605+
device, non_blocking=True)
606+
607+
# Pre-slice draft tensors: [draft_length, batch_size]
608+
draft_slice = draft_tensors[0:draft_length]
609+
610+
# Gather all source data at once using single index_select kernel
611+
# Result shape: [draft_length, num_requests]
612+
gathered = draft_slice.index_select(1, draft_indices_tensor).to(
613+
torch.int32)
614+
615+
# Scatter to new_tokens using advanced indexing (single kernel)
616+
# Shape: [draft_length, num_requests] -> [seq_len, batch_size, beam_width]
617+
target_inputs.new_tokens[draft_position + 1:draft_position +
618+
draft_length + 1, target_indices_tensor,
619+
0] = gathered
620+
621+
# Scatter to next_draft_tokens using advanced indexing (single kernel)
622+
# Shape: [num_requests, draft_length] -> [batch_size, max_draft_len]
623+
target_inputs.next_draft_tokens[target_indices_tensor,
624+
draft_position:draft_position +
625+
draft_length] = gathered.t()
595626

596627
def _setup_draft_batch_and_resources(
597628
self,

0 commit comments

Comments (0)