Skip to content

Commit 8bc701f

Browse files
committed
Optimize draft-token copy: replace the per-request Python loop over device tensors with batched index_select/advanced-indexing scatter (a few kernel launches total)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 096fb4f commit 8bc701f

File tree

1 file changed

+47
-16
lines changed

1 file changed

+47
-16
lines changed

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -556,22 +556,53 @@ def _update_draft_tokens_for_target_inputs(
556556
if target_inputs.next_draft_tokens is None:
557557
return
558558

559-
if draft_tensors is not None:
560-
for req_idx, request in enumerate(draft_batch.all_requests()):
561-
target_req = self.req_id_to_old_request[request.py_request_id]
562-
if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
563-
# Skip prefill requests
564-
continue
565-
# Get the index of the draft/target tokens in the device tensor
566-
draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
567-
target_idx = target_req.py_seq_slot
568-
target_inputs.new_tokens[draft_position + 1:draft_position +
569-
draft_length + 1, target_idx,
570-
0] = draft_tensors[0:draft_length,
571-
draft_idx]
572-
target_inputs.next_draft_tokens[
573-
target_idx, draft_position:draft_position +
574-
draft_length] = draft_tensors[0:draft_length, draft_idx]
559+
draft_indices = []
560+
target_indices = []
561+
for req_idx, request in enumerate(draft_batch.all_requests()):
562+
target_req = self.req_id_to_old_request[request.py_request_id]
563+
if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
564+
# Skip prefill requests
565+
continue
566+
# Get the index of the draft/target tokens in the device tensor
567+
draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
568+
target_idx = target_req.py_seq_slot
569+
draft_indices.append(draft_idx)
570+
target_indices.append(target_idx)
571+
572+
if len(draft_indices) == 0:
573+
return
574+
575+
device = draft_tensors.device
576+
577+
# Create index tensors
578+
draft_indices_tensor = torch.tensor(draft_indices,
579+
dtype=torch.long,
580+
pin_memory=True).to(
581+
device, non_blocking=True)
582+
target_indices_tensor = torch.tensor(target_indices,
583+
dtype=torch.long,
584+
pin_memory=True).to(
585+
device, non_blocking=True)
586+
587+
# Pre-slice draft tensors: [draft_length, batch_size]
588+
draft_slice = draft_tensors[0:draft_length]
589+
590+
# Gather all source data at once using single index_select kernel
591+
# Result shape: [draft_length, num_requests]
592+
gathered = draft_slice.index_select(1, draft_indices_tensor).to(
593+
torch.int32)
594+
595+
# Scatter to new_tokens using advanced indexing (single kernel)
596+
# Shape: [draft_length, num_requests] -> [seq_len, batch_size, beam_width]
597+
target_inputs.new_tokens[draft_position + 1:draft_position +
598+
draft_length + 1, target_indices_tensor,
599+
0] = gathered
600+
601+
# Scatter to next_draft_tokens using advanced indexing (single kernel)
602+
# Shape: [num_requests, draft_length] -> [batch_size, max_draft_len]
603+
target_inputs.next_draft_tokens[target_indices_tensor,
604+
draft_position:draft_position +
605+
draft_length] = gathered.t()
575606

576607
def _setup_draft_batch_and_resources(
577608
self,

0 commit comments

Comments
 (0)