Skip to content

Commit 3d3561e

Browse files
committed
[TRTLLM-9687][chore] Prevent draft requests from changing max_lengths_tensor
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
1 parent af6ff9b commit 3d3561e

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -1536,9 +1536,13 @@ def _process_draft_tokens_tree(
             return num_accepted_draft_tokens - 1

     def _is_new_request(self, request: LlmRequest) -> bool:
-        return not request.is_finished and (
-            (request.is_context_init_state and request.is_last_context_chunk)
-            or request.is_disagg_generation_transmission_complete
+        return (
+            not request.is_finished
+            and not request.py_is_draft
+            and (
+                (request.is_context_init_state and request.is_last_context_chunk)
+                or request.is_disagg_generation_transmission_complete
+            )
         )
```
15431547

15441548
@override

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -230,6 +230,10 @@ class Store(TorchSampler.Store):
     def __post_init__(self):
         pass  # finish_reasons has no size to compare against new_tokens in MTPSampler

+    def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
+        # MTPSampler does not need to setup additional buffers before the sampler step
+        pass
+
     def __init__(self, args: TorchSampler.Args, *, nextn: int):
         self.mapping = None
         self.draft_len = nextn
```

0 commit comments

Comments
 (0)