
Commit 87eb508

[None][fix] restore list[list[list[int]]] in add_token (#8502)
Signed-off-by: ixlmar <[email protected]>
1 parent 85d5aa7 commit 87eb508

4 files changed (+54, -23 lines)
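Background for the change: add_token is called once per request per decode step, and the per-element reads previously went through torch.Tensor indexing plus .item(), which pays Python/ATen dispatch overhead on every access. Converting the host tensor to nested Python lists once (torch.Tensor.tolist()) and indexing those lists is much cheaper in that loop, which is what the restored list[list[list[int]]] signature enables. A minimal, self-contained sketch of the two access patterns (illustrative shape and variable names, not code from this repository):

# Sketch only: compares per-element tensor access with nested-list access.
# The [num_steps, num_seq_slots, beam_width] shape mirrors the indexing order
# used by add_token (new_tokens[step][seq_slot][beam]); the sizes are made up.
import torch

num_steps, num_seq_slots, beam_width = 4, 256, 1
new_tokens = torch.zeros(num_steps, num_seq_slots, beam_width, dtype=torch.int64)
step, seq_slot, beam = 0, 17, 0

# Old pattern: one tensor indexing operation plus .item() per read.
token_from_tensor = int(new_tokens[step, seq_slot, beam].item())

# Restored pattern: convert once, then read via plain Python list indexing.
new_tokens_list = new_tokens.tolist()  # -> list[list[list[int]]]
token_from_list = new_tokens_list[step][seq_slot][beam]

assert token_from_tensor == token_from_list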

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 44 additions & 19 deletions
@@ -290,10 +290,13 @@ def _group_requests_by_strategy_key(
     }
 
 
-def add_token(request: LlmRequest, new_tokens: torch.Tensor, *, beam: int, step: int = 0) -> int:
+def add_token(
+    request: LlmRequest, new_tokens: list[list[list[int]]], *, beam: int, step: int = 0
+) -> int:
+    # NB: Accessing nested lists faster than torch.Tensor or numpy.ndarray
     seq_slot = request.py_seq_slot
     assert seq_slot is not None
-    new_token = cast(int, new_tokens[step][seq_slot][beam].item())
+    new_token = new_tokens[step][seq_slot][beam]
     request.add_new_token(new_token, beam)
     return new_token
 
@@ -700,7 +703,7 @@ def handle_logprobs(
     def _process_draft_tokens_greedy(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens: list[list[list[int]]],
     ) -> int:
         new_token = add_token(request, new_tokens, beam=self.BEAM)
         stop = self._handle_stop_criteria(request, new_token)
@@ -722,7 +725,8 @@ def _process_draft_tokens_greedy(
     def _process_draft_tokens_tree(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_tensor: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
         spec_tree_manager: SpecTreeManager,
     ) -> int:
         """Tree verification for draft token tree based speculative decoding.
@@ -757,7 +761,7 @@ def _process_draft_tokens_tree(
             # TODO: For the last layer of the dynamic tree, we need to resampling all the draft tokens.
             cur_layer_num_nodes = sum(spec_tree_manager.get_top_k_list(cur_draft_layer_idx))
             for i in range(cur_layer_num_nodes):
-                new_token = add_token(request, new_tokens, beam=0, step=i)
+                new_token = add_token(request, new_tokens_list, beam=0, step=i)
             return 0
         else:
             # handle the target model request
@@ -767,7 +771,9 @@ def _process_draft_tokens_tree(
             eagle_paths = spec_tree_manager.get_eagle_paths(seq_slot)
 
             all_draft_tokens = request.py_draft_tokens  # [max_total_draft_tokens]
-            all_target_tokens = new_tokens[:, seq_slot, :].squeeze(-1)  # [max_total_draft_tokens]
+            all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze(
+                -1
+            )  # [max_total_draft_tokens]
             assert all_target_tokens.shape[0] == spec_tree_manager.max_total_draft_tokens + 1
 
             longest_accepted_len = 0
@@ -800,13 +806,15 @@ def _process_draft_tokens_tree(
             if longest_accepted_len == 0:
                 # No draft tokens are accepted.
                 # Take the top-1 token of the first layer as the next new token.
-                new_token = add_token(request, new_tokens, beam=0, step=0)
+                new_token = add_token(request, new_tokens_list, beam=0, step=0)
                 return 0
             else:
                 # Take the longest accepted path as the next new token.
                 num_accepted_draft_tokens = 0
                 for idx in eagle_paths[longest_match_path_idx][:longest_accepted_len]:
-                    new_token = add_token(request, new_tokens, beam=0, step=cast(int, idx.item()))
+                    new_token = add_token(
+                        request, new_tokens_list, beam=0, step=cast(int, idx.item())
+                    )
                     num_accepted_draft_tokens += 1
                     if self._handle_stop_criteria(request, new_token):
                         break
@@ -876,8 +884,10 @@ def _tree_sampling_batch(
     def _process_draft_tokens_rejection_sampling(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
+        new_tokens_tensor: torch.Tensor,
     ) -> int:
+        assert request.py_draft_logits is not None
         # FIXME: Passing a dummy vocab_size could result in unnecessary
         # filtering of vocab_size logits, out of vocab_size in
         # total. The 'sample' below should generally be avoided
@@ -893,7 +903,9 @@ def _process_draft_tokens_rejection_sampling(
             request.py_draft_logits,
             generator=generator,
         )
+        assert draft_probs is not None
         target_probs = request.py_target_probs
+        assert target_probs is not None
         d2t = getattr(request, "d2t", None)
         if d2t is not None:
             vocab_d = draft_probs.shape[-1]
@@ -927,26 +939,27 @@ def _process_draft_tokens_rejection_sampling(
         num_accepted = num_initially_accepted
         for i in range(num_accepted):
             new_token = request.py_draft_tokens[i]
-            new_tokens[i, request.seq_slot, self.BEAM] = new_token
+            new_tokens_tensor[i, request.seq_slot, self.BEAM] = new_token
             request.add_new_token(new_token, self.BEAM)
             stop = self._handle_stop_criteria(request, new_token)
             if stop:
                 num_accepted = i + 1
                 return num_accepted
         if sample_last:
             new_token = sample_rejected(draft_probs, target_probs, generator, num_accepted)
-            new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
+            new_tokens_tensor[num_accepted, request.seq_slot, self.BEAM] = new_token
             request.add_new_token(new_token, self.BEAM)
         else:
-            new_token = add_token(request, new_tokens, beam=self.BEAM, step=num_accepted)
+            new_token = add_token(request, new_tokens_list, beam=self.BEAM, step=num_accepted)
             stop = self._handle_stop_criteria(request, new_token)
 
         return num_accepted
 
     def process_draft_tokens(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_tensor: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
         resource_manager: Optional[ResourceManager] = None,
     ) -> int:
         if (
@@ -957,14 +970,19 @@ def process_draft_tokens(
             if spec_tree_manager is not None:
                 num_accepted = self._process_draft_tokens_tree(
                     request,
-                    new_tokens=new_tokens,
+                    new_tokens_tensor=new_tokens_tensor,
+                    new_tokens_list=new_tokens_list,
                     spec_tree_manager=spec_tree_manager,
                 )
             else:
-                num_accepted = self._process_draft_tokens_greedy(request, new_tokens=new_tokens)
+                num_accepted = self._process_draft_tokens_greedy(
+                    request, new_tokens=new_tokens_list
+                )
             return num_accepted
         else:
-            return self._process_draft_tokens_rejection_sampling(request, new_tokens)
+            return self._process_draft_tokens_rejection_sampling(
+                request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor
+            )
 
     @override
     def update_requests(
@@ -976,15 +994,17 @@ def update_requests(
         if state.sampler_event:
             state.sampler_event.synchronize()
 
+        assert state.host is not None
         new_tokens = state.host.new_tokens
+        new_tokens_list = new_tokens.tolist()
 
         for req in state.scheduled_requests.context_requests:
             if (
                 req.state == LlmRequestState.GENERATION_COMPLETE
                 or req.context_remaining_length != 0
            ):
                 continue
-            new_token = add_token(req, new_tokens, beam=self.BEAM)
+            new_token = add_token(req, new_tokens_list, beam=self.BEAM)
             self._handle_stop_criteria(req, new_token)
             self.handle_logprobs(req, state, beam=self.BEAM, count=1)
             req.py_decoding_iter += 1
@@ -993,7 +1013,12 @@ def update_requests(
             if req.state == LlmRequestState.GENERATION_COMPLETE:
                 continue
             processed = 1
-            num_accepted = self.process_draft_tokens(req, new_tokens, resource_manager)
+            num_accepted = self.process_draft_tokens(
+                req,
+                new_tokens_tensor=new_tokens,
+                new_tokens_list=new_tokens_list,
+                resource_manager=resource_manager,
+            )
             if get_draft_token_length(req) > 0:
                 req.py_num_accepted_draft_tokens = num_accepted
                 req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
@@ -1911,7 +1936,7 @@ def update_requests_multiple_beams_or_drafting(
         state: SampleStateTRTLLM,
         beam_width: int,
     ):
-        new_tokens_host = state.host.new_tokens
+        new_tokens_host = state.host.new_tokens.tolist()
         finished_sum_host = state.host.finished_sum.tolist()
         finish_reasons = state.host.finish_reasons.flatten().tolist()
         sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
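Taken together, the sampler now keeps both representations alive: update_requests calls .tolist() once per batch and hands the nested list to the read paths (add_token, greedy and tree draft verification), while rejection sampling still writes corrected tokens back into the tensor in place. A simplified sketch of that calling pattern, using hypothetical names (the real logic lives in the TorchSampler methods shown above):

# Sketch only: hypothetical, simplified version of the two-representation pattern.
from dataclasses import dataclass

import torch


@dataclass
class FakeRequest:
    seq_slot: int


def update_requests_sketch(new_tokens: torch.Tensor, requests: list[FakeRequest]) -> None:
    # One conversion per batch; subsequent reads are plain list indexing.
    new_tokens_list = new_tokens.tolist()
    for req in requests:
        # Read path: nested-list indexing (what add_token does after this change).
        token = new_tokens_list[0][req.seq_slot][0]
        # Write path: in-place tensor update (what rejection sampling still does),
        # so later consumers of the tensor see the accepted/corrected token.
        new_tokens[0, req.seq_slot, 0] = token


update_requests_sketch(torch.zeros(2, 8, 1, dtype=torch.int64), [FakeRequest(seq_slot=3)])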

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 1 addition & 1 deletion
@@ -256,7 +256,7 @@ def update_requests(
         assert isinstance(state, SampleStateMTP)
 
         state.sampler_event.synchronize()
-        new_tokens = state.host.new_tokens
+        new_tokens = state.host.new_tokens.tolist()
         new_tokens_lens_list = state.host.new_tokens_lens.tolist()
         next_draft_tokens_list = state.host.next_draft_tokens.tolist()
         beam_idx = self.BEAM

tests/unittest/_torch/speculative/test_draft_token_tree_verification.py

Lines changed: 3 additions & 1 deletion
@@ -45,9 +45,11 @@ def run_test(eagle_model_dir, max_seq_len, beam_width, use_dynamic_tree,
             max_beam_width=beam_width,
         ))
 
+    input_new_tokens_list = input_new_tokens.tolist()
     num_accepted_draft_tokens = torch_sampler._process_draft_tokens_tree(
         request=input_request,
-        new_tokens=input_new_tokens,
+        new_tokens_tensor=input_new_tokens,
+        new_tokens_list=input_new_tokens_list,
         spec_tree_manager=spec_tree_manager)
 
     print(f"num_accepted_draft_tokens: {num_accepted_draft_tokens}")

tests/unittest/_torch/speculative/test_torch_rejection_sampling.py

Lines changed: 6 additions & 2 deletions
@@ -1,4 +1,5 @@
 import unittest
+from typing import cast
 
 import numpy as np
 import torch
@@ -24,8 +25,11 @@ def test_get_rejected_indices():
     sampled_regular = []
     for _ in range(num_iter):
         draft_tokens = [
-            torch.multinomial(draft_probs, num_samples=1,
-                              generator=generator).item()
+            cast(
+                int,
+                torch.multinomial(draft_probs,
+                                  num_samples=1,
+                                  generator=generator).item())
         ]
         rejected_indices = get_rejected_indices(draft_probs, target_probs,
                                                 generator, draft_tokens)
