Skip to content

Commit 570094a

Browse files
committed
[TRTLLM-9687][chore] Prevent draft requests from changing max_lengths_tensor
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
1 parent bfa324d commit 570094a

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,9 +1327,13 @@ def _process_draft_tokens_tree(
1327 1327
return num_accepted_draft_tokens - 1
13281328

13291329
def _is_new_request(self, request: LlmRequest) -> bool:
1330-
return not request.is_finished and (
1331-
(request.is_context_init_state and request.is_last_context_chunk)
1332-
or request.is_disagg_generation_transmission_complete
1330+
return (
1331+
not request.is_finished
1332+
and not request.py_is_draft
1333+
and (
1334+
(request.is_context_init_state and request.is_last_context_chunk)
1335+
or request.is_disagg_generation_transmission_complete
1336+
)
13331337
)
13341338

13351339
@override

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,10 @@ class Store(TorchSampler.Store):
230230
def __post_init__(self):
231231
pass # finish_reasons has no size to compare against new_tokens in MTPSampler
232232

233+
def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
234+
# MTPSampler does not need to setup additional buffers before the sampler step
235+
pass
236+
233237
def __init__(self, args: TorchSampler.Args, *, nextn: int):
234238
self.mapping = None
235239
self.draft_len = nextn

0 commit comments

Comments (0)