
Commit 1f9cd59
[TRTLLM-9687][chore] Update test case for write_finish_reasons and adjust override of setup_sampler_step
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
1 parent 5b91375 · commit 1f9cd59

File tree: 3 files changed, +34 −4 lines

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 6 additions & 3 deletions
@@ -887,6 +887,8 @@ def _create_store(self) -> Store:
                 first_finish_reasons=int_tensor(
                     self.CACHE_INDIRECTION_SHAPE[:-1],
                 ),
+                max_lengths_tensor=int_tensor(self.max_num_sequences),
+                end_ids=int_tensor(self.max_num_sequences),
             )
         else:
             return self.Store(
@@ -1330,15 +1332,16 @@ def _is_new_request(self, request: LlmRequest) -> bool:
             or request.is_disagg_generation_transmission_complete
         )

-    def setup_sampler_step(self, requests: ScheduledRequests):
+    @override
+    def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
         """Setup the sampler step for the requests

         Args:
             requests: list[LlmRequest]. The requests to setup the sampler step for
         """
         if self._use_beam_search:
-            self._prepare_beam_search(requests.all_requests())
-            for request in requests.all_requests():
+            self._prepare_beam_search(scheduled_requests.all_requests())
+            for request in scheduled_requests.all_requests():
                 if self._is_new_request(request):
                     self.store.max_lengths_tensor[request.py_seq_slot].fill_(
                         min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
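For reference, the @override decorator added above is typing.override (Python 3.12+), or typing_extensions.override on older versions. A minimal sketch of the pattern, with hypothetical class names, assuming a static checker such as mypy or pyright is run over the code:

# Minimal sketch with hypothetical class names: @override asks a static type
# checker to verify that the decorated method really overrides a method of a
# base class, so renames or removals in the base cannot go unnoticed.
from typing import override  # Python 3.12+; use typing_extensions.override otherwise


class BaseSampler:
    def setup_sampler_step(self, scheduled_requests: list) -> None:
        ...


class BeamSearchSampler(BaseSampler):
    @override
    def setup_sampler_step(self, scheduled_requests: list) -> None:
        # A checker flags this decorator if BaseSampler has no method with
        # this name, e.g. after a rename or removal in the base class.
        print(f"setting up {len(scheduled_requests)} requests")


BeamSearchSampler().setup_sampler_step([object(), object()])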

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 4 additions & 1 deletion
@@ -222,7 +222,10 @@ class Store(TorchSampler.Store):
         next_draft_tokens: torch.Tensor
         new_tokens_lens: torch.Tensor
         max_total_draft_tokens: torch.Tensor
-        finish_reasons: None = None  # Necessary to satisfy the interface of TorchSampler.Store
+        # Necessary to satisfy the interface of TorchSampler.Store
+        finish_reasons: None = None
+        end_ids: None = None
+        max_lengths_tensor: None = None

         def __post_init__(self):
             pass  # finish_reasons has no size to compare against new_tokens in MTPSampler
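The mtp.py change follows the existing pattern of pinning unused parent fields to None. A rough sketch of that pattern, with simplified, hypothetical field sets (not the real Store definitions):

# Sketch only: a child dataclass declares fields it never populates as
# `None = None`, so code written against the parent's attribute set still works.
from dataclasses import dataclass
from typing import Optional

import torch  # assumed available, as in the real Store classes


@dataclass(kw_only=True)
class ParentStore:
    new_tokens: torch.Tensor
    finish_reasons: Optional[torch.Tensor] = None
    end_ids: Optional[torch.Tensor] = None
    max_lengths_tensor: Optional[torch.Tensor] = None


@dataclass(kw_only=True)
class DraftTokenStore(ParentStore):
    next_draft_tokens: torch.Tensor
    # Necessary only to satisfy the parent interface; never read by this sampler.
    finish_reasons: None = None
    end_ids: None = None
    max_lengths_tensor: None = None


store = DraftTokenStore(
    new_tokens=torch.zeros(1, 4, 1, dtype=torch.int32),
    next_draft_tokens=torch.zeros(4, 3, dtype=torch.int32),
)
assert store.end_ids is None and store.max_lengths_tensor is None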

tests/unittest/_torch/sampler/test_torch_sampler.py

Lines changed: 24 additions & 0 deletions
@@ -691,16 +691,40 @@ def setup(requests: list["RequestCase"]):
         seq_slots = torch.tensor(
             [req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
         )
+        seq_lens = torch.tensor(
+            [req.request.max_beam_num_tokens for req in requests], dtype=torch.int32, device="cuda"
+        )
         new_tokens = torch.tensor(
             [req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
         ).T
         sampler.store.new_tokens[:, seq_slots, BEAM] = new_tokens
+        max_seq_lens = torch.tensor(
+            [
+                min(
+                    sampler.max_seq_len, req.request.orig_prompt_len + req.request.py_max_new_tokens
+                )
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        end_ids = torch.tensor(
+            [
+                req.request.py_end_id if req.request.py_end_id is not None else -1
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
+        sampler.store.end_ids[seq_slots] = end_ids

         def run():
             sampler._write_finish_reasons(
                 [req.request for req in requests],
                 finish_reasons=sampler.store.finish_reasons,
                 new_tokens=sampler.store.new_tokens,
+                seq_lens=seq_lens,
                 seq_slots=seq_slots,
             )
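The new setup code seeds max_lengths_tensor and end_ids per sequence slot so that _write_finish_reasons can distinguish end-of-sequence hits from length-limit hits. As a purely illustrative sketch of that kind of classification (not TorchSampler's actual implementation; the constants below are made up):

# Illustrative only: classify per-sequence finish reasons from an end-of-sequence
# id and a maximum allowed length. The reason codes are invented for this sketch;
# the real FinishReason values live in TensorRT-LLM.
import torch

NOT_FINISHED, FINISHED_EOS, FINISHED_LENGTH = 0, 1, 2


def classify_finish_reasons(
    last_tokens: torch.Tensor,  # [num_seqs] most recently sampled token per sequence
    seq_lens: torch.Tensor,     # [num_seqs] current sequence lengths
    end_ids: torch.Tensor,      # [num_seqs] end-of-sequence token id (-1 = none)
    max_lengths: torch.Tensor,  # [num_seqs] per-sequence length limits
) -> torch.Tensor:
    reasons = torch.full_like(last_tokens, NOT_FINISHED)
    reasons[seq_lens >= max_lengths] = FINISHED_LENGTH
    # In this sketch, hitting the end id takes precedence over the length limit.
    reasons[(end_ids >= 0) & (last_tokens == end_ids)] = FINISHED_EOS
    return reasons


print(
    classify_finish_reasons(
        last_tokens=torch.tensor([7, 2, 5]),
        seq_lens=torch.tensor([4, 9, 10]),
        end_ids=torch.tensor([2, 2, -1]),
        max_lengths=torch.tensor([10, 10, 10]),
    )
)  # tensor([0, 1, 2])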
