@@ -290,10 +290,13 @@ def _group_requests_by_strategy_key(
     }
 
 
-def add_token(request: LlmRequest, new_tokens: torch.Tensor, *, beam: int, step: int = 0) -> int:
+def add_token(
+    request: LlmRequest, new_tokens: list[list[list[int]]], *, beam: int, step: int = 0
+) -> int:
+    # NB: Accessing nested lists faster than torch.Tensor or numpy.ndarray
     seq_slot = request.py_seq_slot
     assert seq_slot is not None
-    new_token = cast(int, new_tokens[step][seq_slot][beam].item())
+    new_token = new_tokens[step][seq_slot][beam]
     request.add_new_token(new_token, beam)
     return new_token
 
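The hunk above swaps `new_tokens` from a host `torch.Tensor` to nested Python lists and drops the per-token `cast(int, ...).item()` call, per the NB comment about nested-list access being faster. A minimal sketch of that rationale (illustrative only, not code from this repo; the shape below is made up):

```python
# Hypothetical shape: [steps, seq_slots, beams]; host (CPU) tensor.
import time

import torch

new_tokens = torch.randint(0, 32_000, (8, 256, 1))

# Per-element reads through the tensor: tensor indexing plus .item() for every token.
start = time.perf_counter()
acc = 0
for step in range(new_tokens.shape[0]):
    for slot in range(new_tokens.shape[1]):
        acc += int(new_tokens[step][slot][0].item())
t_tensor = time.perf_counter() - start

# Convert once with .tolist(), then read plain Python ints.
start = time.perf_counter()
new_tokens_list = new_tokens.tolist()
acc = 0
for step in range(len(new_tokens_list)):
    for slot in range(len(new_tokens_list[step])):
        acc += new_tokens_list[step][slot][0]
t_list = time.perf_counter() - start

print(f"tensor + .item(): {t_tensor:.4f}s   tolist() + list indexing: {t_list:.4f}s")
```

The later `update_requests` hunk does the one-time conversion (`new_tokens_list = new_tokens.tolist()`) and threads both views through, keeping the tensor only where slicing or in-place writes are still needed.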
@@ -700,7 +703,7 @@ def handle_logprobs(
     def _process_draft_tokens_greedy(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens: list[list[list[int]]],
     ) -> int:
         new_token = add_token(request, new_tokens, beam=self.BEAM)
         stop = self._handle_stop_criteria(request, new_token)
@@ -722,7 +725,8 @@ def _process_draft_tokens_greedy(
     def _process_draft_tokens_tree(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_tensor: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
         spec_tree_manager: SpecTreeManager,
     ) -> int:
         """Tree verification for draft token tree based speculative decoding.
@@ -757,7 +761,7 @@ def _process_draft_tokens_tree(
             # TODO: For the last layer of the dynamic tree, we need to resampling all the draft tokens.
             cur_layer_num_nodes = sum(spec_tree_manager.get_top_k_list(cur_draft_layer_idx))
             for i in range(cur_layer_num_nodes):
-                new_token = add_token(request, new_tokens, beam=0, step=i)
+                new_token = add_token(request, new_tokens_list, beam=0, step=i)
             return 0
         else:
             # handle the target model request
@@ -767,7 +771,9 @@ def _process_draft_tokens_tree(
             eagle_paths = spec_tree_manager.get_eagle_paths(seq_slot)
 
             all_draft_tokens = request.py_draft_tokens  # [max_total_draft_tokens]
-            all_target_tokens = new_tokens[:, seq_slot, :].squeeze(-1)  # [max_total_draft_tokens]
+            all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze(
+                -1
+            )  # [max_total_draft_tokens]
             assert all_target_tokens.shape[0] == spec_tree_manager.max_total_draft_tokens + 1
 
             longest_accepted_len = 0
@@ -800,13 +806,15 @@ def _process_draft_tokens_tree(
             if longest_accepted_len == 0:
                 # No draft tokens are accepted.
                 # Take the top-1 token of the first layer as the next new token.
-                new_token = add_token(request, new_tokens, beam=0, step=0)
+                new_token = add_token(request, new_tokens_list, beam=0, step=0)
                 return 0
             else:
                 # Take the longest accepted path as the next new token.
                 num_accepted_draft_tokens = 0
                 for idx in eagle_paths[longest_match_path_idx][:longest_accepted_len]:
-                    new_token = add_token(request, new_tokens, beam=0, step=cast(int, idx.item()))
+                    new_token = add_token(
+                        request, new_tokens_list, beam=0, step=cast(int, idx.item())
+                    )
                     num_accepted_draft_tokens += 1
                     if self._handle_stop_criteria(request, new_token):
                         break
@@ -876,8 +884,10 @@ def _tree_sampling_batch(
     def _process_draft_tokens_rejection_sampling(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
+        new_tokens_tensor: torch.Tensor,
     ) -> int:
+        assert request.py_draft_logits is not None
         # FIXME: Passing a dummy vocab_size could result in unnecessary
         # filtering of vocab_size logits, out of vocab_size in
         # total. The 'sample' below should generally be avoided
@@ -893,7 +903,9 @@ def _process_draft_tokens_rejection_sampling(
             request.py_draft_logits,
             generator=generator,
         )
+        assert draft_probs is not None
         target_probs = request.py_target_probs
+        assert target_probs is not None
         d2t = getattr(request, "d2t", None)
         if d2t is not None:
             vocab_d = draft_probs.shape[-1]
@@ -927,26 +939,27 @@ def _process_draft_tokens_rejection_sampling(
         num_accepted = num_initially_accepted
         for i in range(num_accepted):
             new_token = request.py_draft_tokens[i]
-            new_tokens[i, request.seq_slot, self.BEAM] = new_token
+            new_tokens_tensor[i, request.seq_slot, self.BEAM] = new_token
             request.add_new_token(new_token, self.BEAM)
             stop = self._handle_stop_criteria(request, new_token)
             if stop:
                 num_accepted = i + 1
                 return num_accepted
         if sample_last:
             new_token = sample_rejected(draft_probs, target_probs, generator, num_accepted)
-            new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
+            new_tokens_tensor[num_accepted, request.seq_slot, self.BEAM] = new_token
             request.add_new_token(new_token, self.BEAM)
         else:
-            new_token = add_token(request, new_tokens, beam=self.BEAM, step=num_accepted)
+            new_token = add_token(request, new_tokens_list, beam=self.BEAM, step=num_accepted)
             stop = self._handle_stop_criteria(request, new_token)
 
         return num_accepted
 
     def process_draft_tokens(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_tensor: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
         resource_manager: Optional[ResourceManager] = None,
     ) -> int:
         if (
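For context on `_process_draft_tokens_rejection_sampling`, which the hunk above rewires to read accepted tokens from `new_tokens_list` while still writing them back into `new_tokens_tensor`: below is a generic sketch of the standard speculative-decoding acceptance rule, assuming the repo's `get_rejected_indices`/`sample_rejected` helpers do the equivalent work. Names and shapes are illustrative only, not this codebase's API.

```python
import torch


def accept_draft_tokens(
    draft_tokens: list[int],
    draft_probs: torch.Tensor,   # [num_draft_tokens, vocab]
    target_probs: torch.Tensor,  # [num_draft_tokens, vocab]
    generator: torch.Generator,
) -> tuple[int, int]:
    """Return (num_accepted, token sampled at the first rejection, or -1 if none)."""
    num_accepted = 0
    for i, token in enumerate(draft_tokens):
        p = target_probs[i, token]
        q = draft_probs[i, token]
        u = torch.rand((), generator=generator)
        if u * q <= p:  # accept with probability min(1, p / q)
            num_accepted += 1
            continue
        # Rejected: resample from the residual distribution max(0, p - q), renormalized.
        residual = (target_probs[i] - draft_probs[i]).clamp_(min=0)
        residual /= residual.sum()
        replacement = int(torch.multinomial(residual, 1, generator=generator).item())
        return num_accepted, replacement
    return num_accepted, -1  # all drafts accepted; the caller samples one bonus token
```

In the diff itself, accepted tokens are still written into `new_tokens_tensor` (presumably so the shared host buffer stays consistent for downstream consumers), while single-token reads go through `add_token` and the cheaper `new_tokens_list`.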
@@ -957,14 +970,19 @@ def process_draft_tokens(
             if spec_tree_manager is not None:
                 num_accepted = self._process_draft_tokens_tree(
                     request,
-                    new_tokens=new_tokens,
+                    new_tokens_tensor=new_tokens_tensor,
+                    new_tokens_list=new_tokens_list,
                     spec_tree_manager=spec_tree_manager,
                 )
             else:
-                num_accepted = self._process_draft_tokens_greedy(request, new_tokens=new_tokens)
+                num_accepted = self._process_draft_tokens_greedy(
+                    request, new_tokens=new_tokens_list
+                )
             return num_accepted
         else:
-            return self._process_draft_tokens_rejection_sampling(request, new_tokens)
+            return self._process_draft_tokens_rejection_sampling(
+                request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor
+            )
 
     @override
     def update_requests(
@@ -976,15 +994,17 @@ def update_requests(
         if state.sampler_event:
             state.sampler_event.synchronize()
 
+        assert state.host is not None
         new_tokens = state.host.new_tokens
+        new_tokens_list = new_tokens.tolist()
 
         for req in state.scheduled_requests.context_requests:
             if (
                 req.state == LlmRequestState.GENERATION_COMPLETE
                 or req.context_remaining_length != 0
             ):
                 continue
-            new_token = add_token(req, new_tokens, beam=self.BEAM)
+            new_token = add_token(req, new_tokens_list, beam=self.BEAM)
             self._handle_stop_criteria(req, new_token)
             self.handle_logprobs(req, state, beam=self.BEAM, count=1)
             req.py_decoding_iter += 1
@@ -993,7 +1013,12 @@ def update_requests(
             if req.state == LlmRequestState.GENERATION_COMPLETE:
                 continue
             processed = 1
-            num_accepted = self.process_draft_tokens(req, new_tokens, resource_manager)
+            num_accepted = self.process_draft_tokens(
+                req,
+                new_tokens_tensor=new_tokens,
+                new_tokens_list=new_tokens_list,
+                resource_manager=resource_manager,
+            )
             if get_draft_token_length(req) > 0:
                 req.py_num_accepted_draft_tokens = num_accepted
                 req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
@@ -1911,7 +1936,7 @@ def update_requests_multiple_beams_or_drafting(
         state: SampleStateTRTLLM,
         beam_width: int,
     ):
-        new_tokens_host = state.host.new_tokens
+        new_tokens_host = state.host.new_tokens.tolist()
         finished_sum_host = state.host.finished_sum.tolist()
         finish_reasons = state.host.finish_reasons.flatten().tolist()
         sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()