
Commit c59aa8b

[TRTLLM-9962][feat] Some optimizations for two-model spec dec (#10208)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent ae6d576 commit c59aa8b

4 files changed: +74 additions, -42 deletions

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 6 additions & 6 deletions
@@ -895,8 +895,6 @@ def _create_warmup_request(
             return None
 
         num_extra_decoding_steps = self._get_num_extra_decoding_steps()
-        if num_extra_decoding_steps > 0:
-            return None  # Disable autotuning for fused drafting loops for now.
 
         if num_gen_requests > self.batch_size:
             return None
@@ -909,7 +907,10 @@ def _create_warmup_request(
         ctx_requests = []
         gen_requests = []
 
-        max_seq_len = self.max_seq_len - 1
+        # For drafting loops, reduce max_seq_len to leave room for extra decoding steps
+        max_seq_len = self.max_seq_len - 1 - num_extra_decoding_steps
+        if max_seq_len < 1:
+            return None  # Not enough sequence length for drafting loop
         num_full_seqs = 0
         num_left_over_tokens = 0
 

@@ -954,7 +955,8 @@ def _create_warmup_request(
             token_nums=ctx_token_nums,
             is_gen=False,
             max_num_draft_tokens=self.runtime_draft_len,
-            use_mrope=self.use_mrope)
+            use_mrope=self.use_mrope,
+            num_extra_decoding_steps=num_extra_decoding_steps)
 
         if spec_resource_manager is not None:
             spec_resource_manager.add_dummy_requests(
@@ -1546,7 +1548,6 @@ def _prepare_incremental_update_metadata(
 
         return lora_params
 
-    @torch.compile(options={"max-autotune": True})
     def _update_draft_input_tensors(self,
                                     num_accepted_tokens_device: torch.Tensor,
                                     new_tokens_device: torch.Tensor,
@@ -1671,7 +1672,6 @@ def _apply_incremental_update_draft(
 
         return inputs, self.gather_ids_cuda[:num_generation_tokens]
 
-    @torch.compile(options={"max-autotune": True})
     def _update_target_input_tensors(
             self, num_accepted_tokens_device: torch.Tensor,
             new_tokens_device: torch.Tensor,
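
The two warmup hunks above change how _create_warmup_request handles fused drafting loops: instead of skipping warmup/autotuning whenever extra decoding steps are needed, it now shrinks the warmup sequence length to leave room for them and only bails out when nothing is left. A minimal sketch of that sizing rule (the helper name below is illustrative, not the actual TensorRT-LLM API):

from typing import Optional

def warmup_max_seq_len(model_max_seq_len: int,
                       num_extra_decoding_steps: int) -> Optional[int]:
    # Reserve one slot for the sampled token plus one per extra decoding step
    # taken by the fused drafting loop; give up only if nothing remains.
    max_seq_len = model_max_seq_len - 1 - num_extra_decoding_steps
    if max_seq_len < 1:
        return None  # Not enough sequence length for the drafting loop.
    return max_seq_len

assert warmup_max_seq_len(4096, 3) == 4092
assert warmup_max_seq_len(4, 3) is None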

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 0 additions & 1 deletion
@@ -1708,7 +1708,6 @@ def _executor_loop_overlap(self):
         self.iter_counter += 1
 
     @nvtx_range("_accept_draft_tokens")
-    @torch.compile(options={"max-autotune": True})
     def _accept_draft_tokens(
             self, scheduled_batch: ScheduledRequests,
             target_outputs: SampleStateTensors,
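
For context, the decorator removed in this file (and twice in model_engine.py above) is plain torch.compile usage; with it gone, these update functions run eagerly. A self-contained sketch of what the removed decorator did, using a placeholder body rather than the real _accept_draft_tokens logic:

import torch

# The removed decorator compiled the wrapped function with Inductor's
# max-autotune kernel search instead of running it eagerly on every call.
@torch.compile(options={"max-autotune": True})
def bump_kv_lens(kv_lens: torch.Tensor, num_accepted: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real methods update draft/target input tensors.
    return kv_lens + num_accepted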

tensorrt_llm/_torch/speculative/drafting_loops.py

Lines changed: 21 additions & 19 deletions
@@ -120,24 +120,27 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
 
         new_draft_tokens = [self.sample(logits)]
         draft_logits = [logits]
-        with save_metadata_state(attn_metadata, spec_metadata):
-            batch_size = attn_metadata.num_seqs
-
-            new_position_ids = self.prepare_for_generation(
-                attn_metadata, spec_metadata, position_ids)
-            for i in range(self.max_draft_len - 1):
-                logits = self.draft_model.forward(
-                    input_ids=new_draft_tokens[-1],
-                    position_ids=new_position_ids,
-                    attn_metadata=attn_metadata,
-                    spec_metadata=spec_metadata)
-                new_draft_tokens.append(self.sample(logits))
-                draft_logits.append(logits)
-                new_position_ids += 1
-                attn_metadata.kv_lens_cuda[:batch_size] += 1
-                if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata):
-                    spec_metadata.hidden_states_read_indices[:batch_size].copy_(
-                        spec_metadata.hidden_states_write_indices[:batch_size])
+        if self.max_draft_len > 1:
+            is_eagle3 = isinstance(spec_metadata, Eagle3SpecMetadata)
+            with save_metadata_state(attn_metadata, spec_metadata):
+                batch_size = attn_metadata.num_seqs
+
+                new_position_ids = self.prepare_for_generation(
+                    attn_metadata, spec_metadata, position_ids)
+                for i in range(self.max_draft_len - 1):
+                    logits = self.draft_model.forward(
+                        input_ids=new_draft_tokens[-1],
+                        position_ids=new_position_ids,
+                        attn_metadata=attn_metadata,
+                        spec_metadata=spec_metadata)
+                    new_draft_tokens.append(self.sample(logits))
+                    draft_logits.append(logits)
+                    new_position_ids += 1
+                    attn_metadata.kv_lens_cuda[:batch_size] += 1
+                    if i == 0 and is_eagle3:
+                        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
+                            spec_metadata.
+                            hidden_states_write_indices[:batch_size])
 
         return {
             "new_draft_tokens": torch.stack(new_draft_tokens),
@@ -153,7 +156,6 @@ def sample(self, logits: torch.Tensor) -> torch.Tensor:
 
         return tokens
 
-    @torch.compile(options={'max-autotune': True})
     def prepare_for_generation(self, attn_metadata: AttentionMetadata,
                                spec_metadata: SpecMetadata,
                                position_ids: torch.Tensor) -> torch.Tensor:
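
The main behavioral change in the first hunk is the new "if self.max_draft_len > 1:" guard: the initial forward pass already yields one draft token, so when only one is requested the extended generation loop and its metadata save/restore are skipped entirely. A standalone sketch of that control flow, using stub sampling callables instead of the real draft model:

import torch

def draft_tokens(sample_first, sample_next, max_draft_len: int) -> torch.Tensor:
    # The first draft token always comes from the initial forward pass.
    tokens = [sample_first()]
    # Only enter the (more expensive) extended loop when more tokens are needed.
    if max_draft_len > 1:
        for _ in range(max_draft_len - 1):
            tokens.append(sample_next(tokens[-1]))
    return torch.stack(tokens)

out = draft_tokens(lambda: torch.tensor([7]),
                   lambda prev: prev + 1,
                   max_draft_len=1)
assert out.shape == (1, 1)  # no extra decoding work when max_draft_len == 1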

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 47 additions & 16 deletions
@@ -576,22 +576,53 @@ def _update_draft_tokens_for_target_inputs(
         if target_inputs.next_draft_tokens is None:
             return
 
-        if draft_tensors is not None:
-            for req_idx, request in enumerate(draft_batch.all_requests()):
-                target_req = self.req_id_to_old_request[request.py_request_id]
-                if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
-                    # Skip prefill requests
-                    continue
-                # Get the index of the draft/target tokens in the device tensor
-                draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
-                target_idx = target_req.py_seq_slot
-                target_inputs.new_tokens[draft_position + 1:draft_position +
-                                         draft_length + 1, target_idx,
-                                         0] = draft_tensors[0:draft_length,
-                                                            draft_idx]
-                target_inputs.next_draft_tokens[
-                    target_idx, draft_position:draft_position +
-                    draft_length] = draft_tensors[0:draft_length, draft_idx]
+        draft_indices = []
+        target_indices = []
+        for req_idx, request in enumerate(draft_batch.all_requests()):
+            target_req = self.req_id_to_old_request[request.py_request_id]
+            if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                # Skip prefill requests
+                continue
+            # Get the index of the draft/target tokens in the device tensor
+            draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
+            target_idx = target_req.py_seq_slot
+            draft_indices.append(draft_idx)
+            target_indices.append(target_idx)
+
+        if len(draft_indices) == 0:
+            return
+
+        device = draft_tensors.device
+
+        # Create index tensors
+        draft_indices_tensor = torch.tensor(draft_indices,
+                                            dtype=torch.long,
+                                            pin_memory=True).to(
+                                                device, non_blocking=True)
+        target_indices_tensor = torch.tensor(target_indices,
+                                             dtype=torch.long,
+                                             pin_memory=True).to(
+                                                 device, non_blocking=True)
+
+        # Pre-slice draft tensors: [draft_length, batch_size]
+        draft_slice = draft_tensors[0:draft_length]
+
+        # Gather all source data at once using a single index_select kernel
+        # Result shape: [draft_length, num_requests]
+        gathered = draft_slice.index_select(1, draft_indices_tensor).to(
+            torch.int32)
+
+        # Scatter to new_tokens using advanced indexing (single kernel)
+        # Shape: [draft_length, num_requests] -> [seq_len, batch_size, beam_width]
+        target_inputs.new_tokens[draft_position + 1:draft_position +
+                                 draft_length + 1, target_indices_tensor,
+                                 0] = gathered
+
+        # Scatter to next_draft_tokens using advanced indexing (single kernel)
+        # Shape: [num_requests, draft_length] -> [batch_size, max_draft_len]
+        target_inputs.next_draft_tokens[target_indices_tensor,
+                                        draft_position:draft_position +
+                                        draft_length] = gathered.t()
 
     def _setup_draft_batch_and_resources(
             self,
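
This hunk replaces per-request slice assignments (one small device copy per request inside a Python loop) with a single index_select gather followed by two advanced-indexing scatters. A self-contained sketch showing that the two formulations write the same values (tensor names and sizes here are illustrative, not the executor's actual buffers):

import torch

draft_length, num_slots, max_requests = 3, 8, 5
draft_tensors = torch.arange(draft_length * max_requests,
                             dtype=torch.int32).reshape(draft_length, max_requests)
draft_indices = [0, 2, 4]   # columns to read from draft_tensors
target_indices = [1, 3, 6]  # sequence slots to write in the target buffer

# Loop formulation: one copy per request.
new_tokens_loop = torch.zeros(draft_length + 1, num_slots, 1, dtype=torch.int32)
for d, t in zip(draft_indices, target_indices):
    new_tokens_loop[1:draft_length + 1, t, 0] = draft_tensors[:draft_length, d]

# Vectorized formulation: one index_select gather plus one advanced-indexing scatter.
new_tokens_vec = torch.zeros_like(new_tokens_loop)
gathered = draft_tensors[:draft_length].index_select(
    1, torch.tensor(draft_indices, dtype=torch.long))
new_tokens_vec[1:draft_length + 1,
               torch.tensor(target_indices, dtype=torch.long), 0] = gathered

assert torch.equal(new_tokens_loop, new_tokens_vec)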
