Commit aaa87ab

[TRTLLM-7906][feat] Support multiple post process for Responses API (#9908)
Signed-off-by: Junyi Xu <[email protected]>
1 parent ba14a93 commit aaa87ab

4 files changed: +356 additions, -133 deletions

tensorrt_llm/serve/openai_server.py

Lines changed: 78 additions & 34 deletions
@@ -51,20 +51,22 @@
                                         MemoryUpdateRequest, ModelCard,
                                         ModelList, PromptTokensDetails,
                                         ResponsesRequest,
+                                        ResponsesResponse,
                                         UpdateWeightsRequest, UsageInfo,
                                         to_llm_disaggregated_params)
 from tensorrt_llm.serve.postprocess_handlers import (
     ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
-    chat_harmony_post_processor, chat_harmony_streaming_post_processor,
-    chat_response_post_processor, chat_stream_post_processor,
-    completion_response_post_processor, completion_stream_post_processor)
+    ResponsesAPIPostprocArgs, chat_harmony_post_processor,
+    chat_harmony_streaming_post_processor, chat_response_post_processor,
+    chat_stream_post_processor, completion_response_post_processor,
+    completion_stream_post_processor, responses_api_post_processor,
+    responses_api_streaming_post_processor)
 from tensorrt_llm.serve.responses_utils import (ConversationHistoryStore,
+                                                ResponsesStreamingProcessor,
                                                 ServerArrivalTimeMiddleware)
 from tensorrt_llm.serve.responses_utils import \
     create_response as responses_api_create_response
 from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds
-from tensorrt_llm.serve.responses_utils import \
-    process_streaming_events as responses_api_process_streaming_events
 from tensorrt_llm.serve.responses_utils import \
     request_preprocess as responses_api_request_preprocess
 from tensorrt_llm.version import __version__ as VERSION
@@ -119,9 +121,8 @@ def __init__(self,
         self.model_config = None

         # Enable response storage for Responses API
-        self.enable_store = True
-        if len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0:
-            self.enable_store = False
+        self.enable_store = (len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) < 1) and not self.postproc_worker_enabled
+
         self.conversation_store = ConversationHistoryStore()

         model_dir = Path(model)
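The revised toggle enables response storage only when the environment variable is unset or empty and postprocessing workers are not in use. A minimal standalone sketch of that behavior in plain Python (TRTLLM_RESPONSES_API_DISABLE_STORE is the variable from the diff; compute_enable_store and its postproc_worker_enabled argument are illustrative stand-ins, not server attributes):

import os

def compute_enable_store(postproc_worker_enabled: bool) -> bool:
    # Mirrors the expression in the diff: any non-empty value of the
    # environment variable disables response storage, and storage is also
    # disabled whenever postprocessing workers handle the results.
    disable_via_env = len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0
    return (not disable_via_env) and (not postproc_worker_enabled)

if __name__ == "__main__":
    os.environ.pop("TRTLLM_RESPONSES_API_DISABLE_STORE", None)
    assert compute_enable_store(postproc_worker_enabled=False) is True

    os.environ["TRTLLM_RESPONSES_API_DISABLE_STORE"] = "1"
    assert compute_enable_store(postproc_worker_enabled=False) is False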
@@ -942,19 +943,39 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po
             return self.create_error_response(message=str(e), err_type="internal_error")

     async def openai_responses(self, request: ResponsesRequest, raw_request: Request) -> Response:
-        async def create_stream_response(generator, request: ResponsesRequest, sampling_params) -> AsyncGenerator[str, None]:
-            async for event_data in responses_api_process_streaming_events(
-                request=request,
-                sampling_params=sampling_params,
-                generator=generator,
-                model_name=self.model,
-                conversation_store=self.conversation_store,
-                use_harmony=self.use_harmony,
-                reasoning_parser=self.llm.args.reasoning_parser,
-                tool_parser=self.tool_parser,
-                enable_store=self.enable_store
-            ):
-                yield event_data
+        async def create_response(
+                promise: RequestOutput, postproc_params: PostprocParams) -> ResponsesResponse:
+            await promise.aresult()
+            if self.postproc_worker_enabled:
+                response = promise.outputs[0]._postprocess_result
+            else:
+                args = postproc_params.postproc_args
+                response = await responses_api_create_response(
+                    generator=promise,
+                    request=request,
+                    sampling_params=args.sampling_params,
+                    model_name=self.model,
+                    conversation_store=self.conversation_store,
+                    generation_result=None,
+                    enable_store=self.enable_store and request.store,
+                    use_harmony=self.use_harmony,
+                    reasoning_parser=args.reasoning_parser,
+                    tool_parser=args.tool_parser,
+                )
+
+            return response
+
+        async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
+            post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+            streaming_processor = args.streaming_processor
+            initial_responses = streaming_processor.get_initial_responses()
+            for initial_response in initial_responses:
+                yield initial_response
+
+            async for res in promise:
+                pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
+                for pp_res in pp_results:
+                    yield pp_res

         try:
             if request.background:
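The new create_streaming_generator replaces the old create_stream_response wrapper: it first yields the events the streaming processor can produce before any tokens arrive, then post-processes each partial result (locally, or using the postproc-worker result attached to the output). A rough standalone sketch of that pattern, with hypothetical names rather than TensorRT-LLM types:

import asyncio
from typing import AsyncIterator, Callable, List

async def streaming_generator(
        results: AsyncIterator[str],
        initial_events: List[str],
        post_processor: Callable[[str], List[str]]) -> AsyncIterator[str]:
    # Emit the events that can be produced before any tokens arrive
    # (e.g. "response.created"-style bookkeeping events).
    for event in initial_events:
        yield event
    # Then post-process every partial result and forward its event chunks.
    async for res in results:
        for chunk in post_processor(res):
            yield chunk

async def _demo() -> None:
    async def fake_results() -> AsyncIterator[str]:
        for token in ("Hel", "lo"):
            yield token

    events = [e async for e in streaming_generator(
        fake_results(), ["event: created\n\n"], lambda r: [f"data: {r}\n\n"])]
    print(events)

if __name__ == "__main__":
    asyncio.run(_demo())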
@@ -977,38 +998,61 @@ async def create_stream_response(generator, request: ResponsesRequest, sampling_
                 request=request,
                 prev_response=prev_response,
                 conversation_store=self.conversation_store,
-                enable_store=self.enable_store,
+                enable_store=self.enable_store and request.store,
                 use_harmony=self.use_harmony,
                 tokenizer=self.tokenizer if not self.use_harmony else None,
                 model_config=self.model_config if not self.use_harmony else None,
                 processor=self.processor if not self.use_harmony else None,
             )

+            streaming_processor = None
+            if request.stream:
+                # Per-request streaming processor
+                streaming_processor = ResponsesStreamingProcessor(
+                    request=request,
+                    sampling_params=sampling_params,
+                    model_name=self.model,
+                    conversation_store=self.conversation_store,
+                    enable_store=self.enable_store and request.store,
+                    use_harmony=self.use_harmony,
+                    reasoning_parser=self.llm.args.reasoning_parser,
+                    tool_parser=self.tool_parser,
+                )
+
+            postproc_args = ResponsesAPIPostprocArgs(
+                model=self.model,
+                request=request,
+                sampling_params=sampling_params,
+                use_harmony=self.use_harmony,
+                reasoning_parser=self.llm.args.reasoning_parser,
+                tool_parser=self.tool_parser,
+                streaming_processor=streaming_processor,
+            )
+            postproc_params = PostprocParams(
+                post_processor=responses_api_streaming_post_processor
+                if request.stream else responses_api_post_processor,
+                postproc_args=postproc_args,
+            )
             promise = self.llm.generate_async(
                 inputs=input_tokens,
                 sampling_params=sampling_params,
                 streaming=request.stream,
+                _postproc_params=postproc_params if self.postproc_worker_enabled else None,
             )

+            if self.postproc_worker_enabled and request.store:
+                logger.warning("Postproc workers are enabled, request will not be stored!")
+
             asyncio.create_task(self.await_disconnected(raw_request, promise))

             if request.stream:
                 return StreamingResponse(
-                    create_stream_response(promise, request, sampling_params),
+                    content=create_streaming_generator(promise, postproc_params),
                     media_type="text/event-stream"
                 )
             else:
-                return await responses_api_create_response(
-                    generator=promise,
-                    request=request,
-                    sampling_params=sampling_params,
-                    model_name=self.model,
-                    conversation_store=self.conversation_store,
-                    generation_result=None,
-                    enable_store=self.enable_store,
-                    use_harmony=self.use_harmony,
-                    reasoning_parser=self.llm.args.reasoning_parser,
-                    tool_parser=self.tool_parser)
+                response = await create_response(promise, postproc_params)
+                return JSONResponse(content=response.model_dump())
         except CppExecutorError:
             logger.error(traceback.format_exc())
             # If internal executor error is raised, shutdown the server
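The non-streaming and streaming paths now share one PostprocParams object: the request's stream flag picks which post-processor is attached, and the endpoint returns either an SSE StreamingResponse or a JSONResponse built from the final response object. A hedged sketch of that routing reduced to plain FastAPI (the route, payloads, and post-processors below are illustrative only, not the TensorRT-LLM handlers):

from dataclasses import dataclass
from typing import Any, Callable

from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse

app = FastAPI()

@dataclass
class PostprocChoice:
    # Picks the post-processor once per request, so the generation loop
    # does not need to branch on the streaming flag again.
    post_processor: Callable[[str], Any]

def _stream_pp(text: str) -> str:
    return f"data: {text}\n\n"          # SSE-framed chunk

def _final_pp(text: str) -> dict:
    return {"output_text": text}        # final JSON body

@app.post("/v1/responses")
async def responses(body: dict) -> Any:
    stream = bool(body.get("stream", False))
    choice = PostprocChoice(post_processor=_stream_pp if stream else _final_pp)

    async def event_stream():
        for piece in ("Hel", "lo"):     # placeholder for generated chunks
            yield choice.post_processor(piece)

    if stream:
        return StreamingResponse(content=event_stream(), media_type="text/event-stream")
    return JSONResponse(content=choice.post_processor("Hello"))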

tensorrt_llm/serve/postprocess_handlers.py

Lines changed: 46 additions & 1 deletion
@@ -1,11 +1,16 @@
 from dataclasses import dataclass, field
 from typing import Any, List, Literal, Optional, Tuple, Union

+from tensorrt_llm.serve.responses_utils import ResponsesStreamingProcessor
+from tensorrt_llm.serve.responses_utils import \
+    create_response_non_store as responses_api_create_response_non_store
+
 from .._utils import nvtx_range_debug
 from ..executor import (DetokenizedGenerationResultBase, GenerationResult,
                         GenerationResultBase)
 from ..executor.postproc_worker import PostprocArgs
 from ..executor.result import Logprob, TokenLogprobs
+from ..llmapi import SamplingParams
 from ..llmapi.reasoning_parser import (BaseReasoningParser,
                                        ReasoningParserFactory)
 from ..llmapi.tokenizer import TransformersTokenizer
@@ -26,7 +31,8 @@
                                      CompletionResponseStreamChoice,
                                      CompletionStreamResponse, DeltaFunctionCall,
                                      DeltaMessage, DeltaToolCall, FunctionCall,
-                                     PromptTokensDetails, StreamOptions, ToolCall,
+                                     PromptTokensDetails, ResponsesRequest,
+                                     ResponsesResponse, StreamOptions, ToolCall,
                                      UsageInfo, to_disaggregated_params)
 from .tool_parser.base_tool_parser import BaseToolParser
 from .tool_parser.core_types import ToolCallItem
@@ -543,3 +549,42 @@ def chat_harmony_streaming_post_processor(
         num_prompt_tokens=args.num_prompt_tokens,
     )
     return response
+
+
+@dataclass(kw_only=True)
+class ResponsesAPIPostprocArgs(PostprocArgs):
+    model: str
+    request: ResponsesRequest
+    sampling_params: SamplingParams
+    use_harmony: bool
+    reasoning_parser: Optional[str] = None
+    tool_parser: Optional[str] = None
+    streaming_processor: Optional[ResponsesStreamingProcessor] = None
+
+
+@nvtx_range_debug("responses_api_post_processor")
+def responses_api_post_processor(
+        rsp: GenerationResult,
+        args: ResponsesAPIPostprocArgs) -> ResponsesResponse:
+    return responses_api_create_response_non_store(
+        generation_result=rsp,
+        request=args.request,
+        sampling_params=args.sampling_params,
+        model_name=args.model,
+        use_harmony=args.use_harmony,
+        reasoning_parser=args.reasoning_parser,
+        tool_parser=args.tool_parser,
+    )
+
+
+@nvtx_range_debug("responses_api_streaming_post_processor")
+def responses_api_streaming_post_processor(
+        rsp: GenerationResult, args: ResponsesAPIPostprocArgs) -> List[str]:
+    if args.streaming_processor is None:
+        raise ValueError(
+            "streaming_processor is required for streaming post-processing")
+    outputs = args.streaming_processor.process_single_output(rsp)
+    if rsp._done:
+        outputs.append(
+            args.streaming_processor.get_final_response_non_store(rsp))
+    return outputs
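The two new handlers follow the existing args-dataclass convention: per-request state travels in a keyword-only dataclass, the non-streaming handler folds a finished result into one response, and the streaming handler requires its processor and appends a final event once the result is done. A simplified standalone sketch of that shape (the classes and helpers below are illustrative, not the TensorRT-LLM types):

from dataclasses import dataclass
from typing import List, Optional

class ExampleStreamingProcessor:
    def process_single_output(self, text: str) -> List[str]:
        return [f"data: {text}\n\n"]

    def final_event(self) -> str:
        return "data: [DONE]\n\n"

@dataclass(kw_only=True)
class ExamplePostprocArgs:
    # Per-request state handed to the post-processor alongside each result.
    model: str
    streaming_processor: Optional[ExampleStreamingProcessor] = None

def example_post_processor(text: str, args: ExamplePostprocArgs) -> dict:
    # Non-streaming: fold the finished generation into one response payload.
    return {"model": args.model, "output_text": text}

def example_streaming_post_processor(
        text: str, done: bool, args: ExamplePostprocArgs) -> List[str]:
    # Streaming: the processor is mandatory, mirroring the guard in the diff.
    if args.streaming_processor is None:
        raise ValueError("streaming_processor is required for streaming post-processing")
    outputs = args.streaming_processor.process_single_output(text)
    if done:
        outputs.append(args.streaming_processor.final_event())
    return outputs

if __name__ == "__main__":
    args = ExamplePostprocArgs(model="demo", streaming_processor=ExampleStreamingProcessor())
    print(example_post_processor("Hello", args))
    print(example_streaming_post_processor("Hel", False, args))
    print(example_streaming_post_processor("lo", True, args))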
