
Commit af6ff9b

[TRTLLM-9687][chore] Update testcase for write_finish_reasons and adjusted override of setup_sampler_step
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
1 parent 54b5acb commit af6ff9b

File tree

3 files changed: 32 additions & 4 deletions

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 4 additions & 3 deletions
@@ -1541,15 +1541,16 @@ def _is_new_request(self, request: LlmRequest) -> bool:
             or request.is_disagg_generation_transmission_complete
         )

-    def setup_sampler_step(self, requests: ScheduledRequests):
+    @override
+    def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
         """Setup the sampler step for the requests

         Args:
             requests: list[LlmRequest]. The requests to setup the sampler step for
         """
         if self._use_beam_search:
-            self._prepare_beam_search(requests.all_requests())
-        for request in requests.all_requests():
+            self._prepare_beam_search(scheduled_requests.all_requests())
+        for request in scheduled_requests.all_requests():
             if self._is_new_request(request):
                 self.store.max_lengths_tensor[request.py_seq_slot].fill_(
                     min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
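The sampler.py change marks setup_sampler_step with @override and renames its parameter to scheduled_requests so the signature matches the base sampler. A minimal sketch of why the decorator and the matching signature matter (not the TRT-LLM code; BaseSampler, BeamSearchSampler, and the 2048 cap are made up for illustration):

```python
# Minimal sketch, not the TRT-LLM implementation. `override` comes from typing
# (Python 3.12+) or typing_extensions; a static checker flags the decorated
# method if the base class stops providing a matching setup_sampler_step.
from typing_extensions import override


class BaseSampler:
    def setup_sampler_step(self, scheduled_requests: list) -> None:
        raise NotImplementedError


class BeamSearchSampler(BaseSampler):
    @override  # checker error if BaseSampler has no setup_sampler_step
    def setup_sampler_step(self, scheduled_requests: list) -> None:
        for req in scheduled_requests:
            # mirrors the diff's per-slot cap: min(max_seq_len, prompt + max_new_tokens)
            max_len = min(2048, req["prompt_len"] + req["max_new_tokens"])
            print(f"slot {req['slot']} capped at {max_len} tokens")
```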

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 4 additions & 1 deletion
@@ -222,7 +222,10 @@ class Store(TorchSampler.Store):
         next_draft_tokens: torch.Tensor
         new_tokens_lens: torch.Tensor
         max_total_draft_tokens: torch.Tensor
-        finish_reasons: None = None  # Necessary to satisfy the interface of TorchSampler.Store
+        # Necessary to satisfy the interface of TorchSampler.Store
+        finish_reasons: None = None
+        end_ids: None = None
+        max_lengths_tensor: None = None

     def __post_init__(self):
         pass  # finish_reasons has no size to compare against new_tokens in MTPSampler
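In mtp.py, the MTP Store now pins end_ids and max_lengths_tensor to None alongside finish_reasons: MTPSampler never reads them, but TorchSampler.Store declares them, so the subclass overrides the fields with a None annotation and default. A rough, self-contained sketch of the pattern (BaseStore and MTPStore are stand-ins, and the keyword-only dataclass layout is an assumption, not the real classes):

```python
# Rough sketch of overriding dataclass fields with a `None` annotation so a
# subclass can opt out of tensors it never uses; names here are stand-ins.
from dataclasses import dataclass
from typing import Optional


@dataclass(kw_only=True)
class BaseStore:
    new_tokens: list
    finish_reasons: Optional[list] = None

    def __post_init__(self):
        # the generic sampler validates finish_reasons against new_tokens
        if self.finish_reasons is not None:
            assert len(self.finish_reasons) == len(self.new_tokens)


@dataclass(kw_only=True)
class MTPStore(BaseStore):
    next_draft_tokens: list
    # pinned to None: the MTP path never uses these, but the base interface names them
    finish_reasons: None = None
    end_ids: None = None
    max_lengths_tensor: None = None

    def __post_init__(self):
        pass  # nothing to validate; the fields above are always None


store = MTPStore(new_tokens=[1, 2, 3], next_draft_tokens=[4, 5])
print(store.end_ids)  # None
```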

tests/unittest/_torch/sampler/test_torch_sampler.py

Lines changed: 24 additions & 0 deletions
@@ -701,16 +701,40 @@ def setup(requests: list["RequestCase"]):
         seq_slots = torch.tensor(
             [req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
         )
+        seq_lens = torch.tensor(
+            [req.request.max_beam_num_tokens for req in requests], dtype=torch.int32, device="cuda"
+        )
         new_tokens = torch.tensor(
             [req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
         ).T
         sampler.store.new_tokens[:, seq_slots, BEAM] = new_tokens
+        max_seq_lens = torch.tensor(
+            [
+                min(
+                    sampler.max_seq_len, req.request.orig_prompt_len + req.request.py_max_new_tokens
+                )
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        end_ids = torch.tensor(
+            [
+                req.request.py_end_id if req.request.py_end_id is not None else -1
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
+        sampler.store.end_ids[seq_slots] = end_ids

     def run():
         sampler._write_finish_reasons(
             [req.request for req in requests],
             finish_reasons=sampler.store.finish_reasons,
             new_tokens=sampler.store.new_tokens,
+            seq_lens=seq_lens,
             seq_slots=seq_slots,
         )
