11
11
ThreadMessageCreated , ThreadMessageDelta , ThreadRunCompleted ,
12
12
ThreadRunRequiresAction , ThreadRunStepCreated , ThreadRunStepDelta
13
13
)
14
+ from openai .types .beta import AssistantStreamEvent
15
+ from openai .lib .streaming ._assistants import AsyncAssistantEventHandler
14
16
from openai .types .beta .threads .run_submit_tool_outputs_params import ToolOutput
15
17
from openai .types .beta .threads .run import RequiredAction
16
18
from fastapi .responses import StreamingResponse
@@ -122,45 +124,45 @@ async def handle_assistant_stream(
122
124
templates : Jinja2Templates ,
123
125
logger : logging .Logger ,
124
126
stream_manager : AsyncAssistantStreamManager ,
125
- start_step_count : int = 0
127
+ step_id : int = 0
126
128
) -> AsyncGenerator :
127
129
"""
128
130
Async generator to yield SSE events.
129
131
We yield a final 'metadata' dictionary event once we're done.
130
132
"""
131
- step_counter : int = start_step_count
132
133
required_action : RequiredAction | None = None
133
134
run_requires_action_event : ThreadRunRequiresAction | None = None
134
135
136
+ event_handler : AsyncAssistantEventHandler
135
137
async with stream_manager as event_handler :
138
+ event : AssistantStreamEvent
136
139
async for event in event_handler :
137
- logger .info (f"{ event } " )
138
-
139
140
if isinstance (event , ThreadMessageCreated ):
140
- step_counter += 1
141
+ step_id = event . data . id
141
142
142
143
yield sse_format (
143
144
"messageCreated" ,
144
145
templates .get_template ("components/assistant-step.html" ).render (
145
146
step_type = "assistantMessage" ,
146
- stream_name = f"textDelta{ step_counter } "
147
+ stream_name = f"textDelta{ step_id } "
147
148
)
148
149
)
149
150
time .sleep (0.25 ) # Give the client time to render the message
150
151
151
152
if isinstance (event , ThreadMessageDelta ):
152
- logger .info (f"Sending delta with name textDelta{ step_counter } " )
153
153
yield sse_format (
154
- f"textDelta{ step_counter } " ,
154
+ f"textDelta{ step_id } " ,
155
155
event .data .delta .content [0 ].text .value
156
156
)
157
157
158
158
if isinstance (event , ThreadRunStepCreated ) and event .data .type == "tool_calls" :
159
+ step_id = event .data .id
160
+
159
161
yield sse_format (
160
162
f"toolCallCreated" ,
161
163
templates .get_template ('components/assistant-step.html' ).render (
162
164
step_type = 'toolCall' ,
163
- stream_name = f'toolDelta{ step_counter } '
165
+ stream_name = f'toolDelta{ step_id } '
164
166
)
165
167
)
166
168
time .sleep (0.25 ) # Give the client time to render the message
@@ -172,32 +174,32 @@ async def handle_assistant_stream(
172
174
if tool_call .type == "function" :
173
175
if tool_call .function .name :
174
176
yield sse_format (
175
- f"toolDelta{ step_counter } " ,
177
+ f"toolDelta{ step_id } " ,
176
178
tool_call .function .name + "<br>"
177
179
)
178
180
elif tool_call .function .arguments :
179
181
yield sse_format (
180
- f"toolDelta{ step_counter } " ,
182
+ f"toolDelta{ step_id } " ,
181
183
tool_call .function .arguments
182
184
)
183
185
184
186
# Handle code interpreter tool calls
185
187
elif tool_call .type == "code_interpreter" :
186
188
if tool_call .code_interpreter .input :
187
189
yield sse_format (
188
- f"toolDelta{ step_counter } " ,
190
+ f"toolDelta{ step_id } " ,
189
191
f"{ tool_call .code_interpreter .input } "
190
192
)
191
193
if tool_call .code_interpreter .outputs :
192
194
for output in tool_call .code_interpreter .outputs :
193
195
if output .type == "logs" :
194
196
yield sse_format (
195
- f"toolDelta{ step_counter } " ,
197
+ f"toolDelta{ step_id } " ,
196
198
f"{ output .logs } "
197
199
)
198
200
elif output .type == "image" :
199
201
yield sse_format (
200
- f"toolDelta{ step_counter } " ,
202
+ f"toolDelta{ step_id } " ,
201
203
f"{ output .image .file_id } "
202
204
)
203
205
@@ -215,7 +217,7 @@ async def handle_assistant_stream(
215
217
yield {
216
218
"type" : "metadata" ,
217
219
"required_action" : required_action ,
218
- "step_counter " : step_counter ,
220
+ "step_id " : step_id ,
219
221
"run_requires_action_event" : run_requires_action_event
220
222
}
221
223
@@ -224,36 +226,26 @@ async def event_generator():
224
226
Main generator for SSE events. We call our helper function to handle the assistant
225
227
stream, and if the assistant requests a tool call, we do it and then re-run the stream.
226
228
"""
227
- step_counter = 0
228
- # First run of the assistant stream
229
+ step_id = 0
229
230
initial_manager = client .beta .threads .runs .stream (
230
231
assistant_id = assistant_id ,
231
232
thread_id = thread_id ,
232
233
parallel_tool_calls = False
233
234
)
234
235
235
- # We'll re-run the loop if needed for tool calls
236
236
stream_manager = initial_manager
237
- while True :
238
- async for event in handle_assistant_stream (templates , logger , stream_manager , step_counter ):
237
+ while True :
238
+ async for event in handle_assistant_stream (templates , logger , stream_manager , step_id ):
239
239
# Detect the special "metadata" event at the end of the generator
240
240
if isinstance (event , dict ) and event .get ("type" ) == "metadata" :
241
241
required_action : RequiredAction | None = event ["required_action" ]
242
- step_counter : int = event ["step_counter " ]
242
+ step_id : int = event ["step_id " ]
243
243
run_requires_action_event : ThreadRunRequiresAction | None = event ["run_requires_action_event" ]
244
244
245
245
# If the assistant still needs a tool call, do it and then re-stream
246
246
if required_action and required_action .submit_tool_outputs :
247
247
for tool_call in required_action .submit_tool_outputs .tool_calls :
248
- yield sse_format (
249
- "toolCallCreated" ,
250
- templates .get_template ('components/assistant-step.html' ).render (
251
- step_type = 'toolCall' ,
252
- stream_name = f'toolDelta{ step_counter } '
253
- )
254
- )
255
-
256
- if tool_call .type == "function" and tool_call .function .name == "get_weather" :
248
+ if tool_call .type == "function" :
257
249
try :
258
250
args = json .loads (tool_call .function .arguments )
259
251
location = args .get ("location" , "Unknown" )
@@ -262,26 +254,38 @@ async def event_generator():
262
254
logger .error (f"Failed to parse function arguments: { err } " )
263
255
location = "Unknown"
264
256
265
- weather_output : list [dict ] = get_weather (location , dates )
266
- logger .info (f"Weather output: { weather_output } " )
267
-
268
- # Render the weather widget
269
- weather_widget_html : str = templates .get_template (
270
- "components/weather-widget.html"
271
- ).render (
272
- reports = weather_output
273
- )
274
-
275
- # Yield the rendered HTML
276
- yield sse_format ("toolOutput" , weather_widget_html )
277
-
278
- data_for_tool = {
279
- "tool_outputs" : {
280
- "output" : str (weather_output ),
281
- "tool_call_id" : tool_call .id
282
- },
283
- "runId" : run_requires_action_event .data .id ,
284
- }
257
+ try :
258
+ weather_output : list [dict ] = get_weather (location , dates )
259
+ logger .info (f"Weather output: { weather_output } " )
260
+
261
+ # Render the weather widget
262
+ weather_widget_html : str = templates .get_template (
263
+ "components/weather-widget.html"
264
+ ).render (
265
+ reports = weather_output
266
+ )
267
+
268
+ # Yield the rendered HTML
269
+ yield sse_format ("toolOutput" , weather_widget_html )
270
+
271
+ data_for_tool = {
272
+ "tool_outputs" : {
273
+ "output" : str (weather_output ),
274
+ "tool_call_id" : tool_call .id
275
+ },
276
+ "runId" : run_requires_action_event .data .id ,
277
+ }
278
+ except Exception as err :
279
+ error_message = f"Failed to get weather output: { err } "
280
+ logger .error (error_message )
281
+ yield sse_format ("toolOutput" , error_message )
282
+ data_for_tool = {
283
+ "tool_outputs" : {
284
+ "output" : error_message ,
285
+ "tool_call_id" : tool_call .id
286
+ },
287
+ "runId" : run_requires_action_event .data .id ,
288
+ }
285
289
286
290
# Afterwards, create a fresh stream_manager for the next iteration
287
291
new_stream_manager : AsyncAssistantStreamManager = await post_tool_outputs (
0 commit comments