
Commit 5794420

feat: return logits in PyTorch flow (NVIDIA#3221)
Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
1 parent 991939a commit 5794420
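
For context, the logits and log-probs added by this commit surface through the LLM API results. Below is a minimal, hypothetical usage sketch; the entry point for the PyTorch flow and the `SamplingParams`/output field names (`return_generation_logits`, `logprobs`, `generation_logits`) are assumptions based on the public LLM API and may not match this commit exactly.

# Hypothetical usage sketch, not part of this commit; field names are assumptions.
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # PyTorch-flow selection omitted; depends on the build

params = SamplingParams(
    max_tokens=8,
    return_generation_logits=True,  # stored via append_generation_logits() in decoder.py
    logprobs=1,                     # per-token log-probs via append_log_probs()
)

for output in llm.generate(["The capital of France is"], params):
    completion = output.outputs[0]
    print(completion.generation_logits.shape)  # roughly [num_generated_tokens, vocab_size]
    print(completion.logprobs)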

File tree: 11 files changed, +489 −53 lines

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion
@@ -1096,7 +1096,7 @@ def forward(
                 attn_metadata,
                 True,
             )
-            # get accepetd tokens and next draft tokens
+            # get accepted tokens and next draft tokens
             return self.mtp_worker(
                 input_ids=input_ids,
                 position_ids=position_ids,

tensorrt_llm/_torch/pyexecutor/decoder.py

Lines changed: 75 additions & 26 deletions
@@ -19,7 +19,7 @@
                                          SpeculativeDecodingMode)
 from tensorrt_llm.mapping import Mapping
 
-from .llm_request import LlmRequest, LlmRequestState
+from .llm_request import LlmRequest, LlmRequestState, LogProbs
 from .scheduler import ScheduledRequests
 
 
@@ -29,6 +29,9 @@ class DecoderState:
 
     logits: torch.Tensor = None
 
+    # Set when decode_async() has evaluated these to avoid computing again in update_requests()
+    log_probs: list[LogProbs] | None = None
+
     new_tensors_device: dict[str, torch.Tensor] = None
     new_tensors_host: dict[str, torch.Tensor] = None
 
@@ -66,9 +69,13 @@ def update_requests(self, decoder_state: DecoderState) -> None:
         assert (not scheduled_requests.generation_requests)
         for idx, request in enumerate(scheduled_requests.context_requests):
             request.state = LlmRequestState.GENERATION_COMPLETE
-            #NOTE: This is a hack: set finish reason manually and set the beam 0
+            # NOTE: This is a hack: set finish reason manually and set the beam 0
             request.set_finished_reason(FinishReason.LENGTH, 0)
-            request.context_logits = decoder_state.logits[idx]
+            logits = decoder_state.logits[idx]
+            if logits.ndim == 1:
+                # For BERT: Add vocab_size axis to be compatible with LogitsStorage.
+                logits = logits.unsqueeze(-1)
+            request.py_result.append_context_logits(logits)
 
 
 def top_k_sampling_batch(logits, top_k=50):
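
The `unsqueeze(-1)` above only matters for encoder-only (BERT-style) paths where a request's context scores come back as a 1-D vector while the logits storage expects a trailing vocab-like axis. A standalone shape sketch with made-up sizes:

import torch

# Sketch of the shape normalization applied before append_context_logits().
logits = torch.randn(6)              # hypothetical 1-D score vector from an encoder-only model
if logits.ndim == 1:
    logits = logits.unsqueeze(-1)    # -> shape [6, 1], compatible with 2-D logits storage
assert logits.ndim == 2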
@@ -158,10 +165,7 @@ def decode_single_request(request: LlmRequest, logits):
             logits, request.sampling_config.top_k[0])
     else:
         next_tokens, log_probs = greedy_search_sampling_batch(logits)
-    # TODO: enable these lines when log_probs is needed
-    # request.log_probs_async = log_probs
-    # request.set_cum_log_prob(request.cum_log_probs[0] + log_probs[0].item(), 0)
-    return next_tokens
+    return next_tokens, log_probs
 
 
 class TorchDecoder(Decoder):
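
`decode_single_request` now also surfaces the log-probs computed by `greedy_search_sampling_batch`, which is defined earlier in decoder.py and not shown in this hunk. A sketch of what such a helper typically computes; the actual implementation may differ in details:

import torch

def greedy_search_sampling_batch(logits: torch.Tensor):
    # logits: [num_tokens, vocab_size]
    log_probs = torch.log_softmax(logits, dim=-1)      # normalize once per row
    next_tokens = torch.argmax(log_probs, dim=-1)      # greedy token per row
    # log-probability of the token that was actually picked, shape [num_tokens]
    token_log_probs = torch.gather(
        log_probs, dim=-1, index=next_tokens.unsqueeze(-1)).squeeze(-1)
    return next_tokens, token_log_probs

toks, lps = greedy_search_sampling_batch(torch.randn(3, 320))
print(toks.shape, lps.shape)   # torch.Size([3]) torch.Size([3])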
@@ -222,23 +226,51 @@ def update_requests(self, decoder_state: DecoderState) -> None:
             "new_tokens_host"].tolist()
         scheduled_requests = decoder_state.scheduled_requests
 
-        idx = 0
+        request_idx = 0
+        token_idx = 0
         beam_idx = 0
+
+        def advance_idx(num_tokens=1):
+            nonlocal request_idx, token_idx
+            request_idx += 1
+            token_idx += num_tokens
+
+        def handle_logits(request: LlmRequest, count=1):
+            if decoder_state.logits is None:
+                return
+            if not request.py_return_generation_logits and not request.py_return_log_probs:
+                return
+
+            current_slice = slice(token_idx, token_idx + count)
+            current_logits = decoder_state.logits[current_slice]
+
+            request.py_result.append_generation_logits(current_logits)
+
+            if not request.py_return_log_probs:
+                return
+
+            if decoder_state.log_probs:
+                log_probs = decoder_state.log_probs[request_idx]
+            else:
+                _, log_probs = greedy_search_sampling_batch(current_logits)
+            request.py_result.append_log_probs([log_probs.tolist()])
+
         for request in scheduled_requests.context_requests:
             if request.get_context_remaining_length() != 0:
-                idx += 1
+                advance_idx()
                 continue
 
             if request.state != LlmRequestState.GENERATION_COMPLETE:
-                new_token = new_tokens_list[idx]
+                new_token = new_tokens_list[token_idx]
                 num_tokens = request.add_new_token(new_token, beam_idx)
                 self._handle_stop_criteria(request, new_token, num_tokens,
                                            beam_idx)
+                handle_logits(request)
                 request.py_decoding_iter += 1
-            idx += 1
+            advance_idx()
 
         if hasattr(scheduled_requests, 'chunked_requests'):
-            idx += len(scheduled_requests.chunked_requests)
+            request_idx += len(scheduled_requests.chunked_requests)
 
         extend_requests = []
         generation_requests = []
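
The split into `request_idx` and `token_idx` exists because a single request can consume several rows of the flat per-step token/logits buffers (one row per accepted draft token plus the newly sampled token), while `decoder_state.log_probs` keeps one entry per request. A small, self-contained illustration of the slicing, with hypothetical request layout and vocab size:

import torch

# One decode step with three requests: A (no drafts), B (2 draft tokens), C (no drafts).
# The flat logits buffer then holds 1 + 3 + 1 = 5 rows; log_probs stays per request.
logits = torch.randn(5, 320)             # 320 = made-up vocab size
rows_per_request = [1, 3, 1]

token_idx = 0
for request_idx, count in enumerate(rows_per_request):
    current_logits = logits[slice(token_idx, token_idx + count)]  # the slice handle_logits() would take
    print(request_idx, current_logits.shape)                      # [1, 320], [3, 320], [1, 320]
    token_idx += count                                            # analogous to advance_idx(count)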
@@ -249,41 +281,41 @@ def update_requests(self, decoder_state: DecoderState) -> None:
                 generation_requests.append(request)
 
         for request in extend_requests:
-            num_accepted = 0
             if request.state != LlmRequestState.GENERATION_COMPLETE:
-                new_token = new_tokens_list[idx]
+                new_token = new_tokens_list[token_idx]
                 num_tokens = request.add_new_token(new_token, beam_idx)
                 self._handle_stop_criteria(request, new_token, num_tokens,
                                            beam_idx)
-                request.py_decoding_iter += 1
 
                 # Accept draft tokens (if we have any) if and only if they match the new
                 # token exactly.
-                for i in range(len(request.py_draft_tokens)):
-                    draft_token = request.py_draft_tokens[i]
+                num_accepted = 0
+                for draft_token in request.py_draft_tokens:
                     if draft_token != new_token:
                         # Reject.
                         break
-
                     num_accepted += 1
-                    new_token = new_tokens_list[idx + i + 1]
+                    new_token = new_tokens_list[token_idx + num_accepted]
                     num_tokens = request.add_new_token(new_token, beam_idx)
 
                     if self._handle_stop_criteria(request, new_token,
                                                   num_tokens, beam_idx):
                         break
-            request.py_num_accepted_draft_tokens = num_accepted
-            request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
-            idx += len(request.py_draft_tokens) + 1
+                handle_logits(request, num_accepted)
+                request.py_decoding_iter += 1
+                request.py_num_accepted_draft_tokens = num_accepted
+                request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
+            advance_idx(len(request.py_draft_tokens) + 1)
 
         for request in generation_requests:
             if request.state != LlmRequestState.GENERATION_COMPLETE:
-                new_token = new_tokens_list[idx]
+                new_token = new_tokens_list[token_idx]
                 num_tokens = request.add_new_token(new_token, beam_idx)
                 self._handle_stop_criteria(request, new_token, num_tokens,
                                            beam_idx)
+                handle_logits(request)
                 request.py_decoding_iter += 1
-            idx += 1
+            advance_idx()
 
     def _mixed_decode(self, scheduled_requests: ScheduledRequests,
                       model_outputs) -> DecoderState:
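
The acceptance rule in the loop above is strictly sequential: draft token i is kept only if it equals the token the target model actually produced at that position, and the first mismatch rejects everything after it. A tiny worked example with made-up token IDs and no TRT-LLM objects:

# Hypothetical walk-through of the acceptance loop.
new_tokens_list = [42, 17, 99, 7]    # target-model tokens for this request's rows this step
draft_tokens = [42, 17, 55]          # tokens proposed by the draft model
token_idx = 0

new_token = new_tokens_list[token_idx]       # target token at the first draft position
added = [new_token]
num_accepted = 0
for draft_token in draft_tokens:
    if draft_token != new_token:             # first mismatch rejects the remaining drafts
        break
    num_accepted += 1
    new_token = new_tokens_list[token_idx + num_accepted]
    added.append(new_token)

print(num_accepted, added)   # 2 [42, 17, 99]: drafts 42 and 17 matched, 55 did not match 99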
@@ -292,23 +324,33 @@ def _mixed_decode(self, scheduled_requests: ScheduledRequests,
         state = DecoderState(
             scheduled_requests=scheduled_requests,
             logits=logits,
+            log_probs=[],
         )
 
         new_tokens_device_array = []
 
         idx = 0
 
         for request in scheduled_requests.context_requests:
+            assert not request.py_return_context_logits, "Return context logits not supported"
             token_logits = logits[idx:idx + 1, :]
-            new_token = decode_single_request(request, token_logits)
+            new_token, log_probs = decode_single_request(request, token_logits)
             new_tokens_device_array.append(new_token)
+            log_probs = [log_probs.tolist()
+                         ] if request.py_return_log_probs else None
+            state.log_probs.append(log_probs)  # Currently always beam_width=1
             idx += 1
 
         for request in scheduled_requests.generation_requests:
+            if request.state == LlmRequestState.GENERATION_COMPLETE:
+                continue
             assert request.py_draft_tokens is None, "Speculative decoding not supported in SeparateDecoder."
             token_logits = logits[idx:idx + 1, :]
-            new_token = decode_single_request(request, token_logits)
+            new_token, log_probs = decode_single_request(request, token_logits)
             new_tokens_device_array.append(new_token)
+            log_probs = [log_probs.tolist()
+                         ] if request.py_return_log_probs else None
+            state.log_probs.append(log_probs)  # Currently always beam_width=1
             idx += 1
 
         new_tokens_device = torch.cat(new_tokens_device_array)
@@ -351,6 +393,13 @@ def update_one_request(self, request: LlmRequest,
         output_token_idx = request.output_token_idx
         new_token = new_tokens_list[output_token_idx]
         num_tokens = request.add_new_token(new_token, beam_idx)
+
+        current_logits = logits[output_token_idx].unsqueeze(0)
+        request.py_result.append_generation_logits(current_logits)
+        if request.py_return_log_probs:
+            _, log_probs = greedy_search_sampling_batch(current_logits)
+            request.py_result.append_log_probs([log_probs.tolist()])
+
         self._handle_stop_criteria(request, new_token, num_tokens, beam_idx)
         if request.state != LlmRequestState.GENERATION_COMPLETE:
             request.py_decoding_iter += 1
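
In `update_one_request` each call appends a single `[1, vocab_size]` row, so the stored generation logits grow by one row per decoding step. A minimal sketch of that accumulation, with hypothetical sizes and a plain list standing in for the request's logits storage:

import torch

vocab_size = 320                       # made-up vocab size
stored_rows = []                       # stand-in for the per-request logits storage

for step_logits in torch.randn(4, vocab_size):      # four decoding steps
    current_logits = step_logits.unsqueeze(0)        # [1, vocab_size], as in the diff
    stored_rows.append(current_logits)               # analogous to append_generation_logits(...)

generation_logits = torch.cat(stored_rows, dim=0)
print(generation_logits.shape)                       # torch.Size([4, 320])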
