2
2
import time
3
3
from datetime import datetime
4
4
from typing import Any , AsyncGenerator , Dict , List , Optional , Union , cast
5
+ from dataclasses import dataclass
5
6
from fastapi .templating import Jinja2Templates
6
7
from fastapi import APIRouter , Form , Depends , Request
7
8
from fastapi .responses import StreamingResponse , HTMLResponse
14
15
from openai .types .beta import AssistantStreamEvent
15
16
from openai .lib .streaming ._assistants import AsyncAssistantEventHandler
16
17
from openai .types .beta .threads .run_submit_tool_outputs_params import ToolOutput
17
- from openai .types .beta .threads .run import RequiredAction
18
+ from openai .types .beta .threads .run import RequiredAction , Run
18
19
from fastapi .responses import StreamingResponse
19
20
from fastapi import APIRouter , Depends , Form , HTTPException
20
21
from pydantic import BaseModel
24
25
from utils .custom_functions import get_weather
25
26
from utils .sse import sse_format
26
27
28
+ @dataclass
29
+ class AssistantStreamMetadata :
30
+ """Metadata for assistant stream events that require further processing."""
31
+ type : str # Always "metadata"
32
+ required_action : Optional [RequiredAction ]
33
+ step_id : str
34
+ run_requires_action_event : Optional [ThreadRunRequiresAction ]
35
+
36
+ @classmethod
37
+ def create (cls ,
38
+ required_action : Optional [RequiredAction ],
39
+ step_id : str ,
40
+ run_requires_action_event : Optional [ThreadRunRequiresAction ]
41
+ ) -> "AssistantStreamMetadata" :
42
+ """Factory method to create a metadata instance with validation."""
43
+ return cls (
44
+ type = "metadata" ,
45
+ required_action = required_action ,
46
+ step_id = step_id ,
47
+ run_requires_action_event = run_requires_action_event
48
+ )
49
+
50
+ def requires_tool_call (self ) -> bool :
51
+ """Check if this metadata indicates a required tool call."""
52
+ return (self .required_action is not None
53
+ and self .required_action .submit_tool_outputs is not None
54
+ and bool (self .required_action .submit_tool_outputs .tool_calls ))
55
+
56
+ def get_run_id (self ) -> str :
57
+ """Get the run ID from the requires action event, or empty string if none."""
58
+ return self .run_requires_action_event .data .id if self .run_requires_action_event else ""
59
+
27
60
logger : logging .Logger = logging .getLogger ("uvicorn.error" )
28
61
logger .setLevel (logging .DEBUG )
29
62
@@ -125,10 +158,10 @@ async def handle_assistant_stream(
125
158
logger : logging .Logger ,
126
159
stream_manager : AsyncAssistantStreamManager ,
127
160
step_id : str = ""
128
- ) -> AsyncGenerator [Union [Dict [ str , Any ] , str ], None ]:
161
+ ) -> AsyncGenerator [Union [AssistantStreamMetadata , str ], None ]:
129
162
"""
130
163
Async generator to yield SSE events.
131
- We yield a final 'metadata' dictionary event once we're done.
164
+ We yield a final AssistantStreamMetadata instance once we're done.
132
165
"""
133
166
required_action : Optional [RequiredAction ] = None
134
167
run_requires_action_event : Optional [ThreadRunRequiresAction ] = None
@@ -218,18 +251,17 @@ async def handle_assistant_stream(
218
251
if isinstance (event , ThreadRunCompleted ):
219
252
yield sse_format ("endStream" , "DONE" )
220
253
221
- # At the end (or break) of this async generator, we yield a final "metadata" object
222
- yield {
223
- "type" : "metadata" ,
224
- "required_action" : required_action ,
225
- "step_id" : step_id ,
226
- "run_requires_action_event" : run_requires_action_event
227
- }
254
+ # At the end (or break) of this async generator, yield a final AssistantStreamMetadata
255
+ yield AssistantStreamMetadata .create (
256
+ required_action = required_action ,
257
+ step_id = step_id ,
258
+ run_requires_action_event = run_requires_action_event
259
+ )
228
260
229
261
async def event_generator () -> AsyncGenerator [str , None ]:
230
262
"""
231
263
Main generator for SSE events. We call our helper function to handle the assistant
232
- stream, and if the assistant requests a tool call, we do it and then re-run the stream.
264
+ stream, and if the assistant requests a tool call, we do it and then re-stream the stream.
233
265
"""
234
266
step_id : str = ""
235
267
stream_manager : AsyncAssistantStreamManager [AsyncAssistantEventHandler ] = client .beta .threads .runs .stream (
@@ -239,17 +271,13 @@ async def event_generator() -> AsyncGenerator[str, None]:
239
271
)
240
272
241
273
while True :
242
- event : dict [ str , Any ] | str
274
+ event : Union [ AssistantStreamMetadata , str ]
243
275
async for event in handle_assistant_stream (templates , logger , stream_manager , step_id ):
244
- # Detect the special "metadata" event at the end of the generator
245
- if isinstance (event , dict ) and event .get ("type" ) == "metadata" :
246
- required_action = cast (Optional [RequiredAction ], event .get ("required_action" ))
247
- step_id = cast (str , event .get ("step_id" , "" ))
248
- run_requires_action_event = cast (Optional [ThreadRunRequiresAction ], event .get ("run_requires_action_event" ))
249
-
250
- # If the assistant still needs a tool call, do it and then re-stream
251
- if required_action and required_action .submit_tool_outputs and required_action .submit_tool_outputs .tool_calls :
252
- for tool_call in required_action .submit_tool_outputs .tool_calls :
276
+ if isinstance (event , AssistantStreamMetadata ):
277
+ # Use the helper methods from our class
278
+ step_id = event .step_id
279
+ if event .requires_tool_call ():
280
+ for tool_call in event .required_action .submit_tool_outputs .tool_calls : # type: ignore
253
281
if tool_call .type == "function" :
254
282
try :
255
283
args = json .loads (tool_call .function .arguments )
@@ -283,7 +311,7 @@ async def event_generator() -> AsyncGenerator[str, None]:
283
311
"output" : str (weather_output ),
284
312
"tool_call_id" : tool_call .id
285
313
},
286
- "runId" : run_requires_action_event . data . id if run_requires_action_event else "" ,
314
+ "runId" : event . get_run_id () ,
287
315
}
288
316
except Exception as err :
289
317
error_message = f"Failed to get weather output: { err } "
@@ -294,7 +322,7 @@ async def event_generator() -> AsyncGenerator[str, None]:
294
322
"output" : error_message ,
295
323
"tool_call_id" : tool_call .id
296
324
},
297
- "runId" : run_requires_action_event . data . id if run_requires_action_event else "" ,
325
+ "runId" : event . get_run_id () ,
298
326
}
299
327
300
328
# Afterwards, create a fresh stream_manager for the next iteration
0 commit comments