                                     SpeculativeDecodingMode)
 from tensorrt_llm.mapping import Mapping
 
-from .llm_request import LlmRequest, LlmRequestState
+from .llm_request import LlmRequest, LlmRequestState, LogProbs
 from .scheduler import ScheduledRequests
 
 
@@ -29,6 +29,9 @@ class DecoderState:
 
     logits: torch.Tensor = None
 
+    # Set when decode_async() has evaluated these to avoid computing again in update_requests()
+    log_probs: list[LogProbs] | None = None
+
     new_tensors_device: dict[str, torch.Tensor] = None
     new_tensors_host: dict[str, torch.Tensor] = None
 
@@ -66,9 +69,13 @@ def update_requests(self, decoder_state: DecoderState) -> None:
         assert (not scheduled_requests.generation_requests)
         for idx, request in enumerate(scheduled_requests.context_requests):
             request.state = LlmRequestState.GENERATION_COMPLETE
-            #NOTE: This is a hack: set finish reason manually and set the beam 0
+            # NOTE: This is a hack: set finish reason manually and set the beam 0
             request.set_finished_reason(FinishReason.LENGTH, 0)
-            request.context_logits = decoder_state.logits[idx]
+            logits = decoder_state.logits[idx]
+            if logits.ndim == 1:
+                # For BERT: Add vocab_size axis to be compatible with LogitsStorage.
+                logits = logits.unsqueeze(-1)
+            request.py_result.append_context_logits(logits)
 
 
 def top_k_sampling_batch(logits, top_k=50):
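The BERT branch above only adjusts shape: a 1-D logits tensor gains a trailing axis so `append_context_logits` receives the 2-D layout that `LogitsStorage` expects. A minimal shape sketch (the sequence length of 128 is hypothetical):

```python
import torch

# Hypothetical 1-D output, e.g. one score per position from a BERT-style model.
logits = torch.randn(128)          # shape: [seq_len]
if logits.ndim == 1:
    logits = logits.unsqueeze(-1)  # shape: [seq_len, 1], compatible with LogitsStorage
assert logits.shape == (128, 1)
```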
@@ -158,10 +165,7 @@ def decode_single_request(request: LlmRequest, logits):
             logits, request.sampling_config.top_k[0])
     else:
         next_tokens, log_probs = greedy_search_sampling_batch(logits)
-    # TODO: enable these lines when log_probs is needed
-    # request.log_probs_async = log_probs
-    # request.set_cum_log_prob(request.cum_log_probs[0] + log_probs[0].item(), 0)
-    return next_tokens
+    return next_tokens, log_probs
 
 
 class TorchDecoder(Decoder):
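`decode_single_request` now propagates the log-probabilities that the sampling helpers already compute, instead of discarding them. The helper internals are not part of this diff, so the following is only a sketch of the contract the callers rely on, assuming a plain log-softmax/argmax greedy path:

```python
import torch

def greedy_search_sampling_batch_sketch(logits: torch.Tensor):
    # Sketch only: pick the argmax token per row and report its log-probability,
    # yielding the (next_tokens, log_probs) pair that decode_single_request returns.
    log_probs_full = torch.log_softmax(logits, dim=-1)             # [batch, vocab]
    next_tokens = torch.argmax(logits, dim=-1)                     # [batch]
    log_probs = torch.gather(log_probs_full, 1,
                             next_tokens.unsqueeze(1)).squeeze(1)  # [batch]
    return next_tokens, log_probs
```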
@@ -222,23 +226,51 @@ def update_requests(self, decoder_state: DecoderState) -> None:
             "new_tokens_host"].tolist()
         scheduled_requests = decoder_state.scheduled_requests
 
-        idx = 0
+        request_idx = 0
+        token_idx = 0
         beam_idx = 0
+
+        def advance_idx(num_tokens=1):
+            nonlocal request_idx, token_idx
+            request_idx += 1
+            token_idx += num_tokens
+
+        def handle_logits(request: LlmRequest, count=1):
+            if decoder_state.logits is None:
+                return
+            if not request.py_return_generation_logits and not request.py_return_log_probs:
+                return
+
+            current_slice = slice(token_idx, token_idx + count)
+            current_logits = decoder_state.logits[current_slice]
+
+            request.py_result.append_generation_logits(current_logits)
+
+            if not request.py_return_log_probs:
+                return
+
+            if decoder_state.log_probs:
+                log_probs = decoder_state.log_probs[request_idx]
+            else:
+                _, log_probs = greedy_search_sampling_batch(current_logits)
+            request.py_result.append_log_probs([log_probs.tolist()])
+
         for request in scheduled_requests.context_requests:
             if request.get_context_remaining_length() != 0:
-                idx += 1
+                advance_idx()
                 continue
 
             if request.state != LlmRequestState.GENERATION_COMPLETE:
-                new_token = new_tokens_list[idx]
+                new_token = new_tokens_list[token_idx]
                 num_tokens = request.add_new_token(new_token, beam_idx)
                 self._handle_stop_criteria(request, new_token, num_tokens,
                                            beam_idx)
+                handle_logits(request)
                 request.py_decoding_iter += 1
-            idx += 1
+            advance_idx()
 
         if hasattr(scheduled_requests, 'chunked_requests'):
-            idx += len(scheduled_requests.chunked_requests)
+            request_idx += len(scheduled_requests.chunked_requests)
 
         extend_requests = []
         generation_requests = []
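The hunk above splits the old single `idx` into `request_idx` (per-request data such as `decoder_state.log_probs`) and `token_idx` (per-token data such as the flattened logits rows and `new_tokens_list`), because a request carrying draft tokens consumes several token slots. A small worked example with hypothetical draft counts:

```python
# Three requests; the second one carries 2 draft tokens.
draft_counts = [0, 2, 0]

request_idx = 0  # indexes per-request state, e.g. decoder_state.log_probs
token_idx = 0    # indexes flattened per-token state, e.g. logits rows / new_tokens_list

for num_draft in draft_counts:
    count = num_draft + 1   # one sampled token plus the request's draft tokens
    # handle_logits(request, count) slices decoder_state.logits[token_idx:token_idx + count]
    request_idx += 1
    token_idx += count

assert (request_idx, token_idx) == (3, 5)
```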
@@ -249,41 +281,41 @@ def update_requests(self, decoder_state: DecoderState) -> None:
                 generation_requests.append(request)
 
         for request in extend_requests:
-            num_accepted = 0
             if request.state != LlmRequestState.GENERATION_COMPLETE:
-                new_token = new_tokens_list[idx]
+                new_token = new_tokens_list[token_idx]
                 num_tokens = request.add_new_token(new_token, beam_idx)
                 self._handle_stop_criteria(request, new_token, num_tokens,
                                            beam_idx)
-                request.py_decoding_iter += 1
 
                 # Accept draft tokens (if we have any) if and only if they match the new
                 # token exactly.
-                for i in range(len(request.py_draft_tokens)):
-                    draft_token = request.py_draft_tokens[i]
+                num_accepted = 0
+                for draft_token in request.py_draft_tokens:
                     if draft_token != new_token:
                         # Reject.
                         break
-
                     num_accepted += 1
-                    new_token = new_tokens_list[idx + i + 1]
+                    new_token = new_tokens_list[token_idx + num_accepted]
                     num_tokens = request.add_new_token(new_token, beam_idx)
 
                     if self._handle_stop_criteria(request, new_token,
                                                   num_tokens, beam_idx):
                         break
-            request.py_num_accepted_draft_tokens = num_accepted
-            request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
-            idx += len(request.py_draft_tokens) + 1
+                handle_logits(request, num_accepted)
+                request.py_decoding_iter += 1
+                request.py_num_accepted_draft_tokens = num_accepted
+                request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
+            advance_idx(len(request.py_draft_tokens) + 1)
 
         for request in generation_requests:
             if request.state != LlmRequestState.GENERATION_COMPLETE:
-                new_token = new_tokens_list[idx]
+                new_token = new_tokens_list[token_idx]
                 num_tokens = request.add_new_token(new_token, beam_idx)
                 self._handle_stop_criteria(request, new_token, num_tokens,
                                            beam_idx)
+                handle_logits(request)
                 request.py_decoding_iter += 1
-            idx += 1
+            advance_idx()
 
     def _mixed_decode(self, scheduled_requests: ScheduledRequests,
                       model_outputs) -> DecoderState:
@@ -292,23 +324,33 @@ def _mixed_decode(self, scheduled_requests: ScheduledRequests,
         state = DecoderState(
             scheduled_requests=scheduled_requests,
             logits=logits,
+            log_probs=[],
         )
 
         new_tokens_device_array = []
 
         idx = 0
 
         for request in scheduled_requests.context_requests:
+            assert not request.py_return_context_logits, "Return context logits not supported"
             token_logits = logits[idx:idx + 1, :]
-            new_token = decode_single_request(request, token_logits)
+            new_token, log_probs = decode_single_request(request, token_logits)
             new_tokens_device_array.append(new_token)
+            log_probs = [log_probs.tolist()
+                         ] if request.py_return_log_probs else None
+            state.log_probs.append(log_probs)  # Currently always beam_width=1
             idx += 1
 
         for request in scheduled_requests.generation_requests:
+            if request.state == LlmRequestState.GENERATION_COMPLETE:
+                continue
             assert request.py_draft_tokens is None, "Speculative decoding not supported in SeparateDecoder."
             token_logits = logits[idx:idx + 1, :]
-            new_token = decode_single_request(request, token_logits)
+            new_token, log_probs = decode_single_request(request, token_logits)
            new_tokens_device_array.append(new_token)
+            log_probs = [log_probs.tolist()
+                         ] if request.py_return_log_probs else None
+            state.log_probs.append(log_probs)  # Currently always beam_width=1
             idx += 1
 
         new_tokens_device = torch.cat(new_tokens_device_array)
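Per request, `_mixed_decode` stores either `None` or a single-beam list of token log-probabilities, which `update_requests` can later hand straight to `append_log_probs` instead of recomputing them. A sketch of the resulting structure for two hypothetical requests, only the first of which asked for log probs:

```python
import torch

log_probs = torch.tensor([-0.1])   # returned by decode_single_request for one step
entry = [log_probs.tolist()]       # one inner list per beam -> [[-0.1]] (beam_width=1)

# state.log_probs collects one such entry (or None) per scheduled request:
state_log_probs = [entry, None]
```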
@@ -351,6 +393,13 @@ def update_one_request(self, request: LlmRequest,
         output_token_idx = request.output_token_idx
         new_token = new_tokens_list[output_token_idx]
         num_tokens = request.add_new_token(new_token, beam_idx)
+
+        current_logits = logits[output_token_idx].unsqueeze(0)
+        request.py_result.append_generation_logits(current_logits)
+        if request.py_return_log_probs:
+            _, log_probs = greedy_search_sampling_batch(current_logits)
+            request.py_result.append_log_probs([log_probs.tolist()])
+
         self._handle_stop_criteria(request, new_token, num_tokens, beam_idx)
         if request.state != LlmRequestState.GENERATION_COMPLETE:
             request.py_decoding_iter += 1
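In the block above, `logits[output_token_idx]` is a single vocab-sized row; `unsqueeze(0)` restores a batch dimension of 1 so the batched greedy helper can be reused for the log-prob. A shape sketch with a hypothetical vocab size:

```python
import torch

vocab_size = 4
logits = torch.randn(8, vocab_size)                      # one row per in-flight token
output_token_idx = 3
current_logits = logits[output_token_idx].unsqueeze(0)   # shape: [1, vocab_size]
# A batched greedy pass over current_logits then yields one token id and one
# log-probability, i.e. shapes (1,) and (1,).
```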