
Commit db8c63b

[TRTLLM-4517] [feat] Additional model outputs (#7206)
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
1 parent bbae7a0 commit db8c63b

11 files changed: +533 −31 lines

examples/llm-api/quickstart_advanced.py

Lines changed: 18 additions & 1 deletion
@@ -157,6 +157,12 @@ def add_llm_args(parser):
                         default=False,
                         action='store_true')
     parser.add_argument('--logprobs', default=False, action='store_true')
+
+    parser.add_argument('--additional_model_outputs',
+                        type=str,
+                        default=None,
+                        nargs='+')
+
     return parser

@@ -279,7 +285,8 @@ def setup_llm(args, **kwargs):
         logprobs=args.logprobs,
         n=args.n,
         best_of=best_of,
-        use_beam_search=use_beam_search)
+        use_beam_search=use_beam_search,
+        additional_model_outputs=args.additional_model_outputs)
     return llm, sampling_params

@@ -319,6 +326,16 @@ def main():
        if args.logprobs:
            print(f"[{i}]{sequence_id_text} Logprobs: {sequence.logprobs}")

+        if args.additional_model_outputs:
+            for output_name in args.additional_model_outputs:
+                if sequence.additional_context_outputs:
+                    print(
+                        f"[{i}]{sequence_id_text} Context {output_name}: {sequence.additional_context_outputs[output_name]}"
+                    )
+                print(
+                    f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
+                )
+

 if __name__ == '__main__':
     main()
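
For reference, a minimal sketch of driving the new option through the LLM API directly instead of the quickstart script. The model path and the output name "hidden_states" are placeholders, not part of this commit: which names are available depends on what the model's forward pass returns alongside the logits, and whether context outputs appear depends on whether that tensor covers the context tokens.

from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="/path/to/model")  # hypothetical checkpoint location
sampling_params = SamplingParams(
    max_tokens=32,
    # Same list of output names the new --additional_model_outputs flag forwards.
    additional_model_outputs=["hidden_states"],
)

for output in llm.generate(["Hello, my name is"], sampling_params):
    for i, sequence in enumerate(output.outputs):
        # Mirrors the printing added to quickstart_advanced.py above.
        if sequence.additional_context_outputs:
            print(f"[{i}] context hidden_states:",
                  sequence.additional_context_outputs["hidden_states"].shape)
        print(f"[{i}] generation hidden_states:",
              sequence.additional_generation_outputs["hidden_states"].shape)
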
tensorrt_llm/_torch/pyexecutor/handle_additional_outputs.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+from itertools import chain
+from typing import Dict, List
+
+import torch
+
+from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
+from tensorrt_llm._utils import nvtx_range
+from tensorrt_llm.logger import logger
+
+
+class HandleAdditionalOutputs:
+
+    @torch.inference_mode()
+    @nvtx_range("handle_additional_outputs")
+    def __call__(
+        self,
+        context_requests: List[LlmRequest],
+        generation_requests: List[LlmRequest],
+        outputs: Dict[str, torch.Tensor],
+        beam_width: int,
+        num_context_tokens: int,
+    ):
+        """Handles additional context and generation outputs for a batch of requests.
+
+        Args:
+            context_requests: List of context requests to process
+            generation_requests: List of generation requests to process
+            outputs: Additional outputs tensors
+            beam_width: Beam width for the generation requests
+            num_context_tokens: Number of context tokens in the batch
+        """
+
+        additional_outputs = set()
+        for r in chain(context_requests, generation_requests):
+            if r.py_additional_outputs is not None:
+                additional_outputs.update(r.py_additional_outputs)
+
+        if not additional_outputs:
+            return
+
+        output_length_with_context = num_context_tokens + beam_width * len(
+            generation_requests)
+        output_length_without_context = len(
+            context_requests) + beam_width * len(generation_requests)
+
+        gather_context = {}
+        for name in additional_outputs:
+            if outputs[name].shape[0] == output_length_with_context:
+                gather_context[name] = True
+            else:
+                gather_context[name] = False
+
+        output_index_with_context = 0
+        output_index_without_context = 0
+
+        # Copy additional outputs into decoderBuffers.additional_outputs
+        for llm_req in context_requests:
+            context_output_length = llm_req.context_chunk_size
+
+            outputs_begin = output_index_with_context
+            outputs_end = output_index_with_context + context_output_length
+
+            additional_outputs = llm_req.py_additional_outputs
+            req_context_output = False
+            for name in additional_outputs:
+                if gather_context[name]:
+                    output_device_view = outputs[name][
+                        outputs_begin:outputs_end]
+                    llm_req.py_result.append_additional_context_outputs(
+                        name, output_device_view)
+                    req_context_output = True
+
+            if req_context_output and llm_req.prepopulated_prompt_len > 0:
+                logger.warning(
+                    f"Because of KV cache reuse, not all additional context outputs could be produced for request {llm_req.request_id}."
+                )
+
+            output_index_with_context += context_output_length
+            output_index_without_context += 1
+
+            if llm_req.is_last_context_chunk:
+                for name in additional_outputs:
+                    outputs_begin = (output_index_with_context
+                                     if gather_context[name] else
+                                     output_index_without_context) - 1
+                    outputs_end = outputs_begin + 1
+
+                    output_device_view = outputs[name][
+                        outputs_begin:outputs_end]
+                    llm_req.py_result.append_additional_generation_outputs(
+                        name, torch.tile(output_device_view,
+                                         (1, beam_width, 1)))
+
+        for llm_req in generation_requests:
+            additional_outputs = llm_req.py_additional_outputs
+
+            for name in additional_outputs:
+                outputs_begin = (output_index_with_context
+                                 if gather_context[name] else
+                                 output_index_without_context)
+                outputs_end = outputs_begin + beam_width
+
+                output_device_view = outputs[name][
+                    outputs_begin:outputs_end].reshape(1, beam_width, -1)
+                llm_req.py_result.append_additional_generation_outputs(
+                    name, output_device_view)
+
+            output_index_with_context += beam_width
+            output_index_without_context += beam_width
+
+        assert output_index_with_context == output_length_with_context, f"output_index_with_context: {output_index_with_context}, output_length_with_context: {output_length_with_context}"
+        assert output_index_without_context == output_length_without_context, f"output_index_without_context: {output_index_without_context}, output_length_without_context: {output_length_without_context}"
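
A self-contained sketch of the shape bookkeeping above, using made-up batch sizes and output names rather than real LlmRequest objects: an output tensor either has one row per context token plus one per generation beam, or only one row per request and beam, and gather_context records which case applies per name.

import torch

num_context_tokens = 7 + 5           # two context requests with 7 and 5 tokens in this chunk
num_context_requests = 2
num_generation_requests = 3
beam_width = 2

length_with_context = num_context_tokens + beam_width * num_generation_requests        # 18 rows
length_without_context = num_context_requests + beam_width * num_generation_requests   # 8 rows

outputs = {
    "token_level_output": torch.zeros(length_with_context, 8),      # one row per context token and per beam
    "request_level_output": torch.zeros(length_without_context, 8),  # one row per request/beam only
}

# Same check the handler performs for each output name.
gather_context = {
    name: tensor.shape[0] == length_with_context
    for name, tensor in outputs.items()
}
print(gather_context)  # {'token_level_output': True, 'request_level_output': False}

# Generation rows are packed beam_width at a time and reshaped to (1, beam_width, -1),
# mirroring the append_additional_generation_outputs calls above.
offset = num_context_tokens  # start of generation rows for an output gathered over context
for _ in range(num_generation_requests):
    per_request = outputs["token_level_output"][offset:offset + beam_width].reshape(1, beam_width, -1)
    offset += beam_width
    print(per_request.shape)  # torch.Size([1, 2, 8])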

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 56 additions & 3 deletions
@@ -234,7 +234,8 @@ def __init__(self,
                 return_generation_logits: bool = False,
                 exclude_last_generation_logits: bool = False,
                 use_chunked_generation_logits: bool = True,
-                 chunk_size: int = 8):
+                 chunk_size: int = 8,
+                 additional_outputs: Optional[List[str]] = None):
         if streaming and use_chunked_generation_logits:
             assert chunk_size == 1, "chunk_size must be 1 in streaming mode"
         self._streaming = streaming

@@ -253,6 +254,14 @@ def __init__(self,
             chunk_size=self._chunk_size) if return_generation_logits else None
         self._log_probs = LogProbStorage() if return_log_probs else None
         self._mm_embeddings = None
+        self._additional_context_outputs = {
+            name: []
+            for name in additional_outputs
+        } if additional_outputs else None
+        self._additional_generation_outputs = {
+            name: []
+            for name in additional_outputs
+        } if additional_outputs else None

     def append_context_logits(self, context_logits: torch.Tensor):
         if self._context_logits:

@@ -277,6 +286,16 @@ def transfer_remaining_device_logits(self):
         if self._generation_logits:
             self._generation_logits.finalize_chunked_transfer()

+    def append_additional_context_outputs(
+            self, name: str, additional_context_outputs: torch.Tensor):
+        self._additional_context_outputs[name].append(
+            additional_context_outputs.to("cpu", non_blocking=True))
+
+    def append_additional_generation_outputs(
+            self, name: str, additional_generation_outputs: torch.Tensor):
+        self._additional_generation_outputs[name].append(
+            additional_generation_outputs.to("cpu", non_blocking=True))
+
     def set_log_probs(self, log_probs: list[TokenLogprobs],
                       cum_log_probs: list[float]):
         """

@@ -318,12 +337,37 @@ def cum_log_probs(self) -> list[float] | None:
     def mm_embedding_handle(self) -> Dict[str, Any] | None:
         return self._mm_embeddings

+    @property
+    def additional_context_outputs(self) -> Dict[str, torch.Tensor] | None:
+        if self._additional_context_outputs is None:
+            return None
+        outputs = {}
+        for name, output_list in self._additional_context_outputs.items():
+            if len(output_list) == 0:
+                continue
+            outputs[name] = torch.cat(
+                output_list, dim=0) if len(output_list) > 1 else output_list[0]
+        return outputs
+
+    @property
+    def additional_generation_outputs(self) -> Dict[str, torch.Tensor] | None:
+        if self._additional_generation_outputs is None:
+            return None
+        outputs = {}
+        for name, output_list in self._additional_generation_outputs.items():
+            if len(output_list) == 0:
+                continue
+            outputs[name] = torch.cat(
+                output_list, dim=0) if len(output_list) > 1 else output_list[0]
+        return outputs
+

 class LlmResult:
     """LlmResult wraps `bindings.executor.Result` but detour some features to Python implementation"""
     py_result_properties = frozenset(
         ('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs',
-         'mm_embedding_handle'))
+         'mm_embedding_handle', 'additional_context_outputs',
+         'additional_generation_outputs'))

     def __init__(self,
                  result: Union[bytes, tensorrt_llm.bindings.executor.Result],

@@ -388,6 +432,7 @@ def __init__(
         return_generation_logits: bool = False,
         return_logits_device_memory: bool = True,
         exclude_last_generation_logits: bool = False,
+        additional_outputs: Optional[List[str]] = None,
         return_perf_metrics: bool = False,
         stop_words_list: list[list[int]] | None = None,
         llm_request: Optional[

@@ -448,6 +493,8 @@ def __init__(
         self.py_return_context_logits = return_context_logits
         self.py_return_generation_logits = return_generation_logits
         self.py_return_logits_device_memory = return_logits_device_memory
+        self.py_additional_outputs = additional_outputs
+
         self.py_is_draft = is_draft
         # The request's sequence slot ID, an index between 0 (inclusive) and max_batch_size (exclusive).
         self.py_seq_slot = seq_slot

@@ -477,7 +524,8 @@ def __init__(
             return_generation_logits,
             exclude_last_generation_logits,
             use_chunked_generation_logits=self.py_use_chunked_generation_logits,
-            chunk_size=self.py_logits_chunk_size)
+            chunk_size=self.py_logits_chunk_size,
+            additional_outputs=additional_outputs)
         self.child_requests = []

         self._py_embedding_bias_1d: Optional[torch.Tensor] = None

@@ -675,6 +723,11 @@ def executor_request_to_llm_request(
         return_generation_logits=executor_request.output_config.
         return_generation_logits,
         exclude_last_generation_logits=exclude_last_generation_logits,
+        additional_outputs=[
+            output.name for output in
+            executor_request.output_config.additional_model_outputs
+        ] if executor_request.output_config.additional_model_outputs is not None
+        else None,
         draft_tokens=getattr(executor_request, "draft_tokens", None),
         draft_logits=None,
         exclude_input_from_output=executor_request.output_config.
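
A reduced sketch of the accumulation pattern PyResult uses here, with placeholder tensors and names: each append stores a CPU copy per chunk, and the property concatenates the chunks only when more than one was recorded.

import torch

chunks = {"hidden_states": []}

# Two context chunks of one chunked-prefill request arrive in separate batches.
chunks["hidden_states"].append(torch.zeros(7, 16))
chunks["hidden_states"].append(torch.zeros(5, 16))

# Same combination rule as the additional_context_outputs property above.
combined = {
    name: torch.cat(parts, dim=0) if len(parts) > 1 else parts[0]
    for name, parts in chunks.items() if parts
}
print(combined["hidden_states"].shape)  # torch.Size([12, 16])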

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 20 additions & 7 deletions
@@ -2264,21 +2264,34 @@ def _forward_step(self,
         inputs = self._preprocess_inputs(inputs)
         if inputs.get('spec_metadata', None):
             gather_ids = inputs['spec_metadata'].gather_ids
-        if self.without_logits:
-            outputs = self.model_forward(**inputs)
-            return outputs

         # For simplicity, just return all the logits if we have special gather_ids
         # from speculative decoding.
-        logits = self.model_forward(
+        outputs = self.model_forward(
             **inputs,
             return_context_logits=gather_ids is not None
             or gather_context_logits,
         )
-        if gather_ids is not None:
-            return {'logits': logits[gather_ids]}
+
+        if self.without_logits:
+            return outputs
+
+        if isinstance(outputs, dict):
+            # If the model returns a dict, get the logits from it. All other keys are kept.
+            logits = outputs.get('logits', None)
+            # If the logits are not found, no further processing is needed.
+            if logits is None:
+                return outputs
         else:
-            return {'logits': logits}
+            # If the model returns a single tensor, assume it is the logits and wrap it in a dict.
+            logits = outputs
+            outputs = {'logits': logits}
+
+        # If we have special gather_ids, gather the logits
+        if gather_ids is not None:
+            outputs['logits'] = logits[gather_ids]
+
+        return outputs

     @nvtx_range("_forward_step_mm_encoder_only")
     def _forward_step_mm_encoder_only(
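
The normalization in _forward_step can be read in isolation; a standalone sketch with dummy tensors, not the real model_forward signature:

import torch

def normalize_forward_outputs(outputs, gather_ids=None):
    # Dict return: pull the logits out, keep every other key untouched.
    if isinstance(outputs, dict):
        logits = outputs.get('logits', None)
        if logits is None:
            return outputs
    # Bare tensor return: treat it as the logits and wrap it.
    else:
        logits = outputs
        outputs = {'logits': logits}
    # Optional gather for speculative decoding.
    if gather_ids is not None:
        outputs['logits'] = logits[gather_ids]
    return outputs

print(normalize_forward_outputs(torch.zeros(4, 8)).keys())
# dict_keys(['logits'])
print(normalize_forward_outputs({'logits': torch.zeros(4, 8),
                                 'hidden_states': torch.zeros(4, 16)},
                                gather_ids=torch.tensor([1, 3])).keys())
# dict_keys(['logits', 'hidden_states'])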

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 14 additions & 4 deletions
@@ -42,6 +42,7 @@
 from ..speculative.drafter import Drafter
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
+from .handle_additional_outputs import HandleAdditionalOutputs
 from .handle_logits import HandleLogits
 from .kv_cache_connector import KvCacheConnectorManager
 from .kv_cache_transceiver import KvCacheTransceiver

@@ -1815,18 +1816,27 @@ def _sample_async(self, scheduled_batch,
            if batch_outputs is not None:
                num_context_logits_prefix_sum = [0]
                prefix_sum = 0
+                num_context_tokens = 0
                for request in scheduled_batch.context_requests:
-                    prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1
+                    context_chunk_size = request.context_chunk_size
+                    prefix_sum += context_chunk_size if request.py_return_context_logits else 1
                    num_context_logits_prefix_sum.append(prefix_sum)
+                    num_context_tokens += context_chunk_size
+
+                beam_width = self.sampler.beam_width(
+                    scheduled_batch.all_requests())

                HandleLogits()(scheduled_batch.context_requests,
                               scheduled_batch.generation_requests,
-                               batch_outputs["logits"],
-                               self.sampler.beam_width(
-                                   scheduled_batch.all_requests()),
+                               batch_outputs["logits"], beam_width,
                               num_context_logits_prefix_sum,
                               self.sampler.is_generation_model())

+                HandleAdditionalOutputs()(scheduled_batch.context_requests,
+                                          scheduled_batch.generation_requests,
+                                          batch_outputs, beam_width,
+                                          num_context_tokens)
+
            return self.sampler.sample_async(scheduled_batch, batch_outputs,
                                             num_context_logits_prefix_sum)
        except Exception as e:
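
A toy recomputation of the new bookkeeping in _sample_async, with stand-in request objects instead of real LlmRequests: the logits prefix sum still advances by one row per request unless context logits were requested, while num_context_tokens always counts every context token in the chunk.

class FakeRequest:
    """Stand-in for LlmRequest; only the two attributes used here."""

    def __init__(self, context_chunk_size, return_context_logits):
        self.context_chunk_size = context_chunk_size
        self.py_return_context_logits = return_context_logits

context_requests = [FakeRequest(7, False), FakeRequest(5, True)]

num_context_logits_prefix_sum = [0]
prefix_sum = 0
num_context_tokens = 0
for request in context_requests:
    context_chunk_size = request.context_chunk_size
    prefix_sum += context_chunk_size if request.py_return_context_logits else 1
    num_context_logits_prefix_sum.append(prefix_sum)
    num_context_tokens += context_chunk_size

print(num_context_logits_prefix_sum)  # [0, 1, 6]
print(num_context_tokens)             # 12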

tensorrt_llm/executor/result.py

Lines changed: 12 additions & 0 deletions
@@ -102,6 +102,8 @@ class CompletionOutput:
         finish_reason (Literal['stop', 'length', 'timeout', 'cancelled'], optional): The reason why the sequence is finished. Defaults to None.
         stop_reason (int, str, optional): The stop string or token id that caused the completion to stop, None if the completion finished for some other reason. Defaults to None.
         generation_logits (torch.Tensor, optional): The logits on the generated output token ids. Defaults to None.
+        additional_context_outputs (Dict[str, torch.Tensor], optional): The additional context outputs. Defaults to None.
+        additional_generation_outputs (Dict[str, torch.Tensor], optional): The additional generation outputs. Defaults to None.
         disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, optional): Parameters needed for disaggregated serving. Includes the type of request, the first generated tokens, the context request id and any additional state needing to be transferred from context and generation instances. Defaults to None.
         request_perf_metrics (tensorrt_llm.bindings.executor.RequestPerfMetrics, optional): Performance metrics for the request. Defaults to None.

@@ -122,6 +124,8 @@ class CompletionOutput:
                                      'cancelled']] = None
     stop_reason: Optional[Union[int, str]] = None
     generation_logits: Optional[torch.Tensor] = None
+    additional_context_outputs: Optional[Dict[str, torch.Tensor]] = None
+    additional_generation_outputs: Optional[Dict[str, torch.Tensor]] = None
     disaggregated_params: Optional[DisaggregatedParams] = None
     request_perf_metrics: Optional[tllm.RequestPerfMetrics] = None

@@ -387,6 +391,14 @@ def _handle_sequence(self,
             output.generation_logits = response_tensors.generation_logits[
                 src_idx, :output.length]

+        if getattr(response_tensors, 'additional_context_outputs',
+                   None) is not None:
+            output.additional_context_outputs = response_tensors.additional_context_outputs
+
+        if getattr(response_tensors, 'additional_generation_outputs',
+                   None) is not None:
+            output.additional_generation_outputs = response_tensors.additional_generation_outputs
+
         # when sampling_params.n > 1 and is cancelled, make sure all the outputs
         # be marked as cancelled.
         if finish_reasons and finish_reasons[
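
A reduced sketch of the guarded copy above, using SimpleNamespace stand-ins for the response tensors and the CompletionOutput: responses that lack the new attributes leave the fields at their None defaults, which keeps older responses working unchanged.

from types import SimpleNamespace

response_tensors = SimpleNamespace(
    additional_generation_outputs={"hidden_states": "tensor placeholder"})
output = SimpleNamespace(additional_context_outputs=None,
                         additional_generation_outputs=None)

if getattr(response_tensors, 'additional_context_outputs', None) is not None:
    output.additional_context_outputs = response_tensors.additional_context_outputs
if getattr(response_tensors, 'additional_generation_outputs', None) is not None:
    output.additional_generation_outputs = response_tensors.additional_generation_outputs

print(output.additional_context_outputs)     # None
print(output.additional_generation_outputs)  # {'hidden_states': 'tensor placeholder'}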
