                                                 MemoryUpdateRequest, ModelCard,
                                                 ModelList, PromptTokensDetails,
                                                 ResponsesRequest,
+                                                ResponsesResponse,
                                                 UpdateWeightsRequest, UsageInfo,
                                                 to_llm_disaggregated_params)
 from tensorrt_llm.serve.postprocess_handlers import (
     ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
-    chat_harmony_post_processor, chat_harmony_streaming_post_processor,
-    chat_response_post_processor, chat_stream_post_processor,
-    completion_response_post_processor, completion_stream_post_processor)
+    ResponsesAPIPostprocArgs, chat_harmony_post_processor,
+    chat_harmony_streaming_post_processor, chat_response_post_processor,
+    chat_stream_post_processor, completion_response_post_processor,
+    completion_stream_post_processor, responses_api_post_processor,
+    responses_api_streaming_post_processor)
 from tensorrt_llm.serve.responses_utils import (ConversationHistoryStore,
+                                                ResponsesStreamingProcessor,
                                                 ServerArrivalTimeMiddleware)
 from tensorrt_llm.serve.responses_utils import \
     create_response as responses_api_create_response
 from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds
-from tensorrt_llm.serve.responses_utils import \
-    process_streaming_events as responses_api_process_streaming_events
 from tensorrt_llm.serve.responses_utils import \
     request_preprocess as responses_api_request_preprocess
 from tensorrt_llm.version import __version__ as VERSION
@@ -119,9 +121,8 @@ def __init__(self,
         self.model_config = None
 
         # Enable response storage for Responses API
-        self.enable_store = True
-        if len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0:
-            self.enable_store = False
+        self.enable_store = (len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) < 1) and not self.postproc_worker_enabled
+
         self.conversation_store = ConversationHistoryStore()
 
         model_dir = Path(model)
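Note the new gate on `enable_store`: storage now stays on only when `TRTLLM_RESPONSES_API_DISABLE_STORE` is unset (or empty) and postprocess workers are off, since stored responses rely on the in-process `ConversationHistoryStore`. A minimal sketch of the equivalent, more explicit logic (the helper name is hypothetical, not part of the change):

```python
import os

def _compute_enable_store(postproc_worker_enabled: bool) -> bool:
    # Hypothetical helper mirroring the one-liner in the hunk above:
    # any non-empty value of the env var disables storage, and so does
    # running postprocessing in out-of-process workers.
    disabled_by_env = len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0
    return not disabled_by_env and not postproc_worker_enabled
```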
@@ -942,19 +943,39 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams
             return self.create_error_response(message=str(e), err_type="internal_error")
 
     async def openai_responses(self, request: ResponsesRequest, raw_request: Request) -> Response:
-        async def create_stream_response(generator, request: ResponsesRequest, sampling_params) -> AsyncGenerator[str, None]:
-            async for event_data in responses_api_process_streaming_events(
-                    request=request,
-                    sampling_params=sampling_params,
-                    generator=generator,
-                    model_name=self.model,
-                    conversation_store=self.conversation_store,
-                    use_harmony=self.use_harmony,
-                    reasoning_parser=self.llm.args.reasoning_parser,
-                    tool_parser=self.tool_parser,
-                    enable_store=self.enable_store
-            ):
-                yield event_data
+        async def create_response(
+                promise: RequestOutput, postproc_params: PostprocParams) -> ResponsesResponse:
+            await promise.aresult()
+            if self.postproc_worker_enabled:
+                response = promise.outputs[0]._postprocess_result
+            else:
+                args = postproc_params.postproc_args
+                response = await responses_api_create_response(
+                    generator=promise,
+                    request=request,
+                    sampling_params=args.sampling_params,
+                    model_name=self.model,
+                    conversation_store=self.conversation_store,
+                    generation_result=None,
+                    enable_store=self.enable_store and request.store,
+                    use_harmony=self.use_harmony,
+                    reasoning_parser=args.reasoning_parser,
+                    tool_parser=args.tool_parser,
+                )
+
+            return response
+
+        async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
+            post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+            streaming_processor = args.streaming_processor
+            initial_responses = streaming_processor.get_initial_responses()
+            for initial_response in initial_responses:
+                yield initial_response
+
+            async for res in promise:
+                pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
+                for pp_res in pp_results:
+                    yield pp_res
 
         try:
             if request.background:
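Both new helpers share one split: when postproc workers are enabled, each result arrives pre-serialized on `outputs[0]._postprocess_result`; otherwise the post-processor runs inline in the server process. A stripped-down sketch of the streaming shape (the parameter types and the `initial_events` source are stand-ins, not TRT-LLM APIs):

```python
from typing import AsyncIterator, Callable, Iterable, List

async def stream_events(promise,                        # async-iterable of partial results
                        initial_events: Iterable[str],  # e.g. response.created / in_progress
                        post_process: Callable[..., List[str]],
                        worker_already_processed: bool) -> AsyncIterator[str]:
    # 1) Emit the handshake events the Responses API expects before any tokens.
    for event in initial_events:
        yield event
    # 2) Relay per-iteration chunks, postprocessed inline or by a worker.
    async for res in promise:
        chunks = (res.outputs[0]._postprocess_result
                  if worker_already_processed else post_process(res))
        for chunk in chunks:
            yield chunk
```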
@@ -977,38 +998,61 @@ async def create_stream_response(generator, request: ResponsesRequest, sampling_params
                 request=request,
                 prev_response=prev_response,
                 conversation_store=self.conversation_store,
-                enable_store=self.enable_store,
+                enable_store=self.enable_store and request.store,
                 use_harmony=self.use_harmony,
                 tokenizer=self.tokenizer if not self.use_harmony else None,
                 model_config=self.model_config if not self.use_harmony else None,
                 processor=self.processor if not self.use_harmony else None,
             )
 
+            streaming_processor = None
+            if request.stream:
+                # Per-request streaming processor
+                streaming_processor = ResponsesStreamingProcessor(
+                    request=request,
+                    sampling_params=sampling_params,
+                    model_name=self.model,
+                    conversation_store=self.conversation_store,
+                    enable_store=self.enable_store and request.store,
+                    use_harmony=self.use_harmony,
+                    reasoning_parser=self.llm.args.reasoning_parser,
+                    tool_parser=self.tool_parser,
+                )
+
+            postproc_args = ResponsesAPIPostprocArgs(
+                model=self.model,
+                request=request,
+                sampling_params=sampling_params,
+                use_harmony=self.use_harmony,
+                reasoning_parser=self.llm.args.reasoning_parser,
+                tool_parser=self.tool_parser,
+                streaming_processor=streaming_processor,
+            )
+            postproc_params = PostprocParams(
+                post_processor=responses_api_streaming_post_processor
+                if request.stream else responses_api_post_processor,
+                postproc_args=postproc_args,
+            )
             promise = self.llm.generate_async(
                 inputs=input_tokens,
                 sampling_params=sampling_params,
                 streaming=request.stream,
+                _postproc_params=postproc_params if self.postproc_worker_enabled else None,
             )
 
+            if self.postproc_worker_enabled and request.store:
+                logger.warning("Postproc workers are enabled, request will not be stored!")
+
             asyncio.create_task(self.await_disconnected(raw_request, promise))
 
             if request.stream:
                 return StreamingResponse(
-                    create_stream_response(promise, request, sampling_params),
+                    content=create_streaming_generator(promise, postproc_params),
                     media_type="text/event-stream"
                 )
             else:
-                return await responses_api_create_response(
-                    generator=promise,
-                    request=request,
-                    sampling_params=sampling_params,
-                    model_name=self.model,
-                    conversation_store=self.conversation_store,
-                    generation_result=None,
-                    enable_store=self.enable_store,
-                    use_harmony=self.use_harmony,
-                    reasoning_parser=self.llm.args.reasoning_parser,
-                    tool_parser=self.tool_parser)
+                response = await create_response(promise, postproc_params)
+                return JSONResponse(content=response.model_dump())
         except CppExecutorError:
             logger.error(traceback.format_exc())
             # If internal executor error is raised, shutdown the server
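End to end, the handler now builds `ResponsesAPIPostprocArgs` once per request, picks `responses_api_streaming_post_processor` or `responses_api_post_processor` based on `request.stream`, and can hand the bundle to out-of-process workers via `_postproc_params`. The wire format is unchanged; a hedged client sketch against a local `trtllm-serve` instance (the URL, model name, and `[DONE]` sentinel are assumptions):

```python
import json
import requests

BASE = "http://localhost:8000/v1/responses"  # assumed local endpoint

# Non-streaming: handled by create_response() and returned as JSONResponse.
resp = requests.post(BASE, json={"model": "my-model", "input": "Hi", "stream": False})
print(resp.json().get("id"))

# Streaming: handled by create_streaming_generator() as text/event-stream.
with requests.post(BASE,
                   json={"model": "my-model", "input": "Hi", "stream": True},
                   stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        if line and line.startswith("data: ") and not line.endswith("[DONE]"):
            print(json.loads(line[len("data: "):]).get("type"))
```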