Skip to content

Commit 11d8f4a

Browse files
committed
None: Enable autotuner warmup for CDL
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 237fd0e commit 11d8f4a

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -846,10 +846,6 @@ def _create_warmup_request(
846846
if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
847847
return None
848848

849-
num_extra_decoding_steps = self._get_num_extra_decoding_steps()
850-
if num_extra_decoding_steps > 0:
851-
return None # Disable autotuning for fused drafting loops for now.
852-
853849
num_ctx_tokens = num_tokens - num_gen_tokens
854850
num_ctx_requests = 0
855851
ctx_requests = []
@@ -868,10 +864,16 @@ def _create_warmup_request(
868864
if num_ctx_requests + num_gen_tokens > self.batch_size:
869865
return None # Not enough batch size to fill the request
870866

867+
# For fused drafting loops, each generation request needs extra blocks
868+
# for the tokens that will be generated during the loop
869+
num_extra_decoding_steps = self._get_num_extra_decoding_steps()
870+
tokens_per_gen = 1 + num_extra_decoding_steps
871+
blocks_per_gen = math.ceil(tokens_per_gen /
872+
kv_cache_manager.tokens_per_block)
871873
blocks_to_use = num_full_seqs * math.ceil(
872874
max_seq_len / kv_cache_manager.tokens_per_block) + math.ceil(
873-
num_left_over_tokens /
874-
kv_cache_manager.tokens_per_block) + num_gen_tokens
875+
num_left_over_tokens / kv_cache_manager.tokens_per_block
876+
) + num_gen_tokens * blocks_per_gen
875877

876878
if blocks_to_use > available_blocks:
877879
return None
@@ -899,7 +901,8 @@ def _create_warmup_request(
899901
token_nums=[1] * num_gen_tokens,
900902
is_gen=True,
901903
max_num_draft_tokens=self.max_total_draft_tokens,
902-
use_mrope=self.use_mrope)
904+
use_mrope=self.use_mrope,
905+
num_extra_decoding_steps=num_extra_decoding_steps)
903906
if spec_resource_manager is not None:
904907
spec_resource_manager.add_dummy_requests(request_ids=list(
905908
range(num_ctx_requests, num_ctx_requests + num_gen_tokens)))

0 commit comments

Comments
 (0)