77from typing import Any , Callable , Optional
88
99from fastapi import FastAPI , status
10+ from fastapi .responses import StreamingResponse
1011from opentelemetry .instrumentation .fastapi import FastAPIInstrumentor
1112from pydantic import BaseModel
1213from starlette .responses import RedirectResponse
@@ -110,16 +111,29 @@ class InvokeResponse(BaseModel):
110111
111112 logging .getLogger ("etl_uvicorn.fastapi" )
112113
113- async def wrap_fn (func : Callable , kwargs : Optional [dict [str , Any ]] = None ) -> InvokeResponse :
114+ ResponseType = StreamingResponse if inspect .isasyncgenfunction (func ) else InvokeResponse
115+
116+ async def wrap_fn (func : Callable , kwargs : Optional [dict [str , Any ]] = None ) -> ResponseType :
114117 usage : list [UsageData ] = []
115118 request_dict = kwargs if kwargs else {}
116119 if "usage" in inspect .signature (func ).parameters :
117120 request_dict ["usage" ] = usage
118121 else :
119122 logger .warning ("usage data not an expected parameter, omitting" )
120123 try :
121- output = await invoke_func (func = func , kwargs = request_dict )
122- return InvokeResponse (usage = usage , status_code = status .HTTP_200_OK , output = output )
124+ if inspect .isasyncgenfunction (func ):
125+ # Stream response if function is an async generator
126+
127+ async def _stream_response ():
128+ async for output in func (** (request_dict or {})):
129+ yield InvokeResponse (
130+ usage = usage , status_code = status .HTTP_200_OK , output = output
131+ ).model_dump_json () + "\n "
132+
133+ return StreamingResponse (_stream_response (), media_type = "application/x-ndjson" )
134+ else :
135+ output = await invoke_func (func = func , kwargs = request_dict )
136+ return InvokeResponse (usage = usage , status_code = status .HTTP_200_OK , output = output )
123137 except Exception as invoke_error :
124138 logger .error (f"failed to invoke plugin: { invoke_error } " , exc_info = True )
125139 return InvokeResponse (
@@ -132,7 +146,7 @@ async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> In
132146 if input_schema_model .model_fields :
133147
134148 @fastapi_app .post ("/invoke" , response_model = InvokeResponse )
135- async def run_job (request : input_schema_model ) -> InvokeResponse :
149+ async def run_job (request : input_schema_model ) -> ResponseType :
136150 log_func_and_body (func = func , body = request .json ())
137151 # Create dictionary from pydantic model while preserving underlying types
138152 request_dict = {f : getattr (request , f ) for f in request .model_fields }
@@ -144,7 +158,7 @@ async def run_job(request: input_schema_model) -> InvokeResponse:
144158 else :
145159
146160 @fastapi_app .post ("/invoke" , response_model = InvokeResponse )
147- async def run_job () -> InvokeResponse :
161+ async def run_job () -> ResponseType :
148162 log_func_and_body (func = func )
149163 return await wrap_fn (
150164 func = func ,
0 commit comments