Skip to content

Commit 25f2434

Browse files
authored
fix: Set correct draft_token_nums to dummy requests for torch compilation with MTP (NVIDIA#3053)
Set correct draft_token_nums to dummy requests for torch compilation with MTP.

Signed-off-by: Hui Gao <[email protected]>
1 parent 268933b commit 25f2434

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -413,9 +413,17 @@ def get_torch_compile_warmup_request(batch_size, num_tokens):
413413
num_tokens / kv_cache_manager.tokens_per_block):
414414
# Should only need (at most) one more page per request.
415415
is_gen = num_tokens == 1
416-
requests = kv_cache_manager.add_dummy_requests(list(
417-
range(batch_size)), [num_tokens] * batch_size,
418-
is_gen=is_gen)
416+
max_num_draft_tokens = self.spec_config.max_draft_tokens if self.spec_config is not None and is_gen else 0
417+
418+
requests = kv_cache_manager.add_dummy_requests(
419+
list(range(batch_size)), [num_tokens] * batch_size,
420+
is_gen=is_gen,
421+
max_num_draft_tokens=max_num_draft_tokens)
422+
423+
if spec_resource_manager is not None:
424+
spec_resource_manager.add_dummy_requests(
425+
request_ids=list(range(batch_size)))
426+
419427
result = ScheduledRequests()
420428
result.context_requests = []
421429
result.generation_requests = []

0 commit comments

Comments (0)