Skip to content

Commit 21fa197

Browse files
committed
Enable autotuner warmup for drafting loops
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent e159a09 commit 21fa197

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -857,15 +857,16 @@ def _create_warmup_request(
 857  857              return None
 858  858
 859  859          num_extra_decoding_steps = self._get_num_extra_decoding_steps()
 860       -      if num_extra_decoding_steps > 0:
 861       -          return None  # Disable autotuning for fused drafting loops for now.
 862  860
 863  861          num_ctx_tokens = num_tokens - num_gen_tokens
 864  862          num_ctx_requests = 0
 865  863          ctx_requests = []
 866  864          gen_requests = []
 867  865
 868       -      max_seq_len = self.max_seq_len - 1
      866  +      # For drafting loops, reduce max_seq_len to leave room for extra decoding steps
      867  +      max_seq_len = self.max_seq_len - 1 - num_extra_decoding_steps
      868  +      if max_seq_len < 1:
      869  +          return None  # Not enough sequence length for drafting loop
 869  870          num_full_seqs = 0
 870  871          num_left_over_tokens = 0
 871  872
@@ -896,7 +897,8 @@ def _create_warmup_request(
 896  897              token_nums=ctx_token_nums,
 897  898              is_gen=False,
 898  899              max_num_draft_tokens=self.runtime_draft_len,
 899       -          use_mrope=self.use_mrope)
      900  +          use_mrope=self.use_mrope,
      901  +          num_extra_decoding_steps=num_extra_decoding_steps)
 900  902
 901  903          if spec_resource_manager is not None:
 902  904              spec_resource_manager.add_dummy_requests(
@@ -909,7 +911,8 @@ def _create_warmup_request(
 909  911              token_nums=[1] * num_gen_tokens,
 910  912              is_gen=True,
 911  913              max_num_draft_tokens=self.max_total_draft_tokens,
 912       -          use_mrope=self.use_mrope)
      914  +          use_mrope=self.use_mrope,
      915  +          num_extra_decoding_steps=num_extra_decoding_steps)
 913  916          if spec_resource_manager is not None:
 914  917              spec_resource_manager.add_dummy_requests(request_ids=list(
 915  918                  range(num_ctx_requests, num_ctx_requests + num_gen_tokens)))

Comments: 0