1
1
import logging
2
2
import time
3
3
from datetime import datetime
4
- from typing import Any , AsyncGenerator
4
+ from typing import Any , AsyncGenerator , Dict , List , Optional , Union , cast
5
5
from fastapi .templating import Jinja2Templates
6
6
from fastapi import APIRouter , Form , Depends , Request
7
7
from fastapi .responses import StreamingResponse , HTMLResponse
38
38
39
39
# Utility function for submitting tool outputs to the assistant
40
40
class ToolCallOutputs (BaseModel ):
41
- tool_outputs : Any
41
+ tool_outputs : Dict [ str , Any ]
42
42
runId : str
43
43
44
- async def post_tool_outputs (client : AsyncOpenAI , data : dict , thread_id : str ):
44
+ async def post_tool_outputs (client : AsyncOpenAI , data : Dict [ str , Any ], thread_id : str ) -> AsyncAssistantStreamManager :
45
45
"""
46
46
data is expected to be something like
47
47
{
@@ -55,7 +55,7 @@ async def post_tool_outputs(client: AsyncOpenAI, data: dict, thread_id: str):
55
55
try :
56
56
outputs_list = [
57
57
ToolOutput (
58
- output = data ["tool_outputs" ]["output" ],
58
+ output = str ( data ["tool_outputs" ]["output" ]) ,
59
59
tool_call_id = data ["tool_outputs" ]["tool_call_id" ]
60
60
)
61
61
]
@@ -124,14 +124,14 @@ async def handle_assistant_stream(
124
124
templates : Jinja2Templates ,
125
125
logger : logging .Logger ,
126
126
stream_manager : AsyncAssistantStreamManager ,
127
- step_id : int = 0
128
- ) -> AsyncGenerator :
127
+ step_id : str = ""
128
+ ) -> AsyncGenerator [ Union [ Dict [ str , Any ], str ], None ] :
129
129
"""
130
130
Async generator to yield SSE events.
131
131
We yield a final 'metadata' dictionary event once we're done.
132
132
"""
133
- required_action : RequiredAction | None = None
134
- run_requires_action_event : ThreadRunRequiresAction | None = None
133
+ required_action : Optional [ RequiredAction ] = None
134
+ run_requires_action_event : Optional [ ThreadRunRequiresAction ] = None
135
135
136
136
event_handler : AsyncAssistantEventHandler
137
137
async with stream_manager as event_handler :
@@ -149,11 +149,13 @@ async def handle_assistant_stream(
149
149
)
150
150
time .sleep (0.25 ) # Give the client time to render the message
151
151
152
- if isinstance (event , ThreadMessageDelta ):
153
- yield sse_format (
154
- f"textDelta{ step_id } " ,
155
- event .data .delta .content [0 ].text .value
156
- )
152
+ if isinstance (event , ThreadMessageDelta ) and event .data .delta .content :
153
+ content = event .data .delta .content [0 ]
154
+ if hasattr (content , 'text' ) and content .text and content .text .value :
155
+ yield sse_format (
156
+ f"textDelta{ step_id } " ,
157
+ content .text .value
158
+ )
157
159
158
160
if isinstance (event , ThreadRunStepCreated ) and event .data .type == "tool_calls" :
159
161
step_id = event .data .id
@@ -167,47 +169,50 @@ async def handle_assistant_stream(
167
169
)
168
170
time .sleep (0.25 ) # Give the client time to render the message
169
171
170
- if isinstance (event , ThreadRunStepDelta ) and event .data .delta .step_details .type == "tool_calls" :
171
- tool_call = event .data .delta .step_details .tool_calls [0 ]
172
-
173
- # Handle function tool calls
174
- if tool_call .type == "function" :
175
- if tool_call .function .name :
176
- yield sse_format (
177
- f"toolDelta{ step_id } " ,
178
- tool_call .function .name + "<br>"
179
- )
180
- elif tool_call .function .arguments :
181
- yield sse_format (
182
- f"toolDelta{ step_id } " ,
183
- tool_call .function .arguments
184
- )
185
-
186
- # Handle code interpreter tool calls
187
- elif tool_call .type == "code_interpreter" :
188
- if tool_call .code_interpreter .input :
189
- yield sse_format (
190
- f"toolDelta{ step_id } " ,
191
- f"{ tool_call .code_interpreter .input } "
192
- )
193
- if tool_call .code_interpreter .outputs :
194
- for output in tool_call .code_interpreter .outputs :
195
- if output .type == "logs" :
196
- yield sse_format (
197
- f"toolDelta{ step_id } " ,
198
- f"{ output .logs } "
199
- )
200
- elif output .type == "image" :
201
- yield sse_format (
202
- f"toolDelta{ step_id } " ,
203
- f"{ output .image .file_id } "
204
- )
172
+ if isinstance (event , ThreadRunStepDelta ) and event .data .delta .step_details and event .data .delta .step_details .type == "tool_calls" :
173
+ tool_calls = event .data .delta .step_details .tool_calls
174
+ if tool_calls :
175
+ # TODO: Support parallel function calling
176
+ tool_call = tool_calls [0 ]
177
+
178
+ # Handle function tool call
179
+ if tool_call .type == "function" :
180
+ if tool_call .function and tool_call .function .name :
181
+ yield sse_format (
182
+ f"toolDelta{ step_id } " ,
183
+ tool_call .function .name + "<br>"
184
+ )
185
+ if tool_call .function and tool_call .function .arguments :
186
+ yield sse_format (
187
+ f"toolDelta{ step_id } " ,
188
+ tool_call .function .arguments
189
+ )
190
+
191
+ # Handle code interpreter tool calls
192
+ elif tool_call .type == "code_interpreter" :
193
+ if tool_call .code_interpreter and tool_call .code_interpreter .input :
194
+ yield sse_format (
195
+ f"toolDelta{ step_id } " ,
196
+ str (tool_call .code_interpreter .input )
197
+ )
198
+ if tool_call .code_interpreter and tool_call .code_interpreter .outputs :
199
+ for output in tool_call .code_interpreter .outputs :
200
+ if output .type == "logs" and output .logs :
201
+ yield sse_format (
202
+ f"toolDelta{ step_id } " ,
203
+ str (output .logs )
204
+ )
205
+ elif output .type == "image" and output .image and output .image .file_id :
206
+ yield sse_format (
207
+ f"toolDelta{ step_id } " ,
208
+ str (output .image .file_id )
209
+ )
205
210
206
211
# If the assistant run requires an action (a tool call), break and handle it
207
212
if isinstance (event , ThreadRunRequiresAction ):
208
213
required_action = event .data .required_action
209
214
run_requires_action_event = event
210
- if required_action .submit_tool_outputs :
215
+ if required_action and required_action .submit_tool_outputs :
211
216
break
212
217
213
218
if isinstance (event , ThreadRunCompleted ):
@@ -221,45 +226,50 @@ async def handle_assistant_stream(
221
226
"run_requires_action_event" : run_requires_action_event
222
227
}
223
228
224
- async def event_generator ():
229
+ async def event_generator () -> AsyncGenerator [ str , None ] :
225
230
"""
226
231
Main generator for SSE events. We call our helper function to handle the assistant
227
232
stream, and if the assistant requests a tool call, we do it and then re-run the stream.
228
233
"""
229
- step_id = 0
230
- initial_manager = client .beta .threads .runs .stream (
234
+ step_id : str = ""
235
+ stream_manager : AsyncAssistantStreamManager [ AsyncAssistantEventHandler ] = client .beta .threads .runs .stream (
231
236
assistant_id = assistant_id ,
232
237
thread_id = thread_id ,
233
238
parallel_tool_calls = False
234
239
)
235
240
236
- stream_manager = initial_manager
237
241
while True :
242
+ event : dict [str , Any ] | str
238
243
async for event in handle_assistant_stream (templates , logger , stream_manager , step_id ):
239
244
# Detect the special "metadata" event at the end of the generator
240
245
if isinstance (event , dict ) and event .get ("type" ) == "metadata" :
241
- required_action : RequiredAction | None = event [ "required_action" ]
242
- step_id : int = event [ "step_id" ]
243
- run_requires_action_event : ThreadRunRequiresAction | None = event [ "run_requires_action_event" ]
246
+ required_action = cast ( Optional [ RequiredAction ], event . get ( "required_action" ))
247
+ step_id = cast ( str , event . get ( "step_id" , "" ))
248
+ run_requires_action_event = cast ( Optional [ ThreadRunRequiresAction ], event . get ( "run_requires_action_event" ))
244
249
245
250
# If the assistant still needs a tool call, do it and then re-stream
246
- if required_action and required_action .submit_tool_outputs :
251
+ if required_action and required_action .submit_tool_outputs and required_action . submit_tool_outputs . tool_calls :
247
252
for tool_call in required_action .submit_tool_outputs .tool_calls :
248
253
if tool_call .type == "function" :
249
254
try :
250
255
args = json .loads (tool_call .function .arguments )
251
256
location = args .get ("location" , "Unknown" )
252
- dates = args .get ("dates" , [datetime .today ()])
257
+ dates_raw = args .get ("dates" , [datetime .today ().strftime ("%Y-%m-%d" )])
258
+ dates = [
259
+ datetime .strptime (d , "%Y-%m-%d" ) if isinstance (d , str ) else d
260
+ for d in dates_raw
261
+ ]
253
262
except Exception as err :
254
263
logger .error (f"Failed to parse function arguments: { err } " )
255
264
location = "Unknown"
265
+ dates = [datetime .today ()]
256
266
257
267
try :
258
- weather_output : list [ dict ] = get_weather (location , dates )
268
+ weather_output : list = get_weather (location , dates )
259
269
logger .info (f"Weather output: { weather_output } " )
260
270
261
271
# Render the weather widget
262
- weather_widget_html : str = templates .get_template (
272
+ weather_widget_html = templates .get_template (
263
273
"components/weather-widget.html"
264
274
).render (
265
275
reports = weather_output
@@ -273,7 +283,7 @@ async def event_generator():
273
283
"output" : str (weather_output ),
274
284
"tool_call_id" : tool_call .id
275
285
},
276
- "runId" : run_requires_action_event .data .id ,
286
+ "runId" : run_requires_action_event .data .id if run_requires_action_event else "" ,
277
287
}
278
288
except Exception as err :
279
289
error_message = f"Failed to get weather output: { err } "
@@ -284,24 +294,24 @@ async def event_generator():
284
294
"output" : error_message ,
285
295
"tool_call_id" : tool_call .id
286
296
},
287
- "runId" : run_requires_action_event .data .id ,
297
+ "runId" : run_requires_action_event .data .id if run_requires_action_event else "" ,
288
298
}
289
299
290
- # Afterwards, create a fresh stream_manager for the next iteration
291
- new_stream_manager : AsyncAssistantStreamManager = await post_tool_outputs (
292
- client ,
293
- data_for_tool ,
294
- thread_id
295
- )
296
- stream_manager = new_stream_manager
297
- # proceed to rerun the loop
298
- break
300
+ # Afterwards, create a fresh stream_manager for the next iteration
301
+ new_stream_manager = await post_tool_outputs (
302
+ client ,
303
+ data_for_tool ,
304
+ thread_id
305
+ )
306
+ stream_manager = new_stream_manager
307
+ # proceed to rerun the loop
308
+ break
299
309
else :
300
310
# No more tool calls needed; we're done streaming
301
311
return
302
312
else :
303
313
# Normal SSE events: yield them to the client
304
- yield event
314
+ yield str ( event )
305
315
306
316
return StreamingResponse (
307
317
event_generator (),
0 commit comments