1
1
import logging
2
2
import time
3
- from typing import Any
3
+ from typing import Any , AsyncGenerator
4
4
from fastapi .templating import Jinja2Templates
5
5
from fastapi import APIRouter , Form , Depends , Request
6
6
from fastapi .responses import StreamingResponse , HTMLResponse
16
16
from pydantic import BaseModel
17
17
import json
18
18
19
-
20
-
21
-
22
- # Import our get_weather method
23
19
from utils .weather import get_weather
24
20
from utils .sse import sse_format
25
21
@@ -104,19 +100,31 @@ async def stream_response(
104
100
thread_id : str ,
105
101
client : AsyncOpenAI = Depends (lambda : AsyncOpenAI ())
106
102
) -> StreamingResponse :
103
+ """
104
+ Streams the assistant response via Server-Sent Events (SSE). If the assistant requires
105
+ a tool call, we capture that action, invoke the tool, and then re-run the stream
106
+ until completion. This is done in a DRY way by extracting the streaming logic
107
+ into a helper function.
108
+ """
107
109
108
- async def event_generator ():
109
- step_counter : int = 0
110
+ async def handle_assistant_stream (
111
+ templates : Jinja2Templates ,
112
+ logger : logging .Logger ,
113
+ stream_manager : AsyncAssistantStreamManager ,
114
+ start_step_count : int = 0
115
+ ) -> AsyncGenerator :
116
+ """
117
+ Async generator to yield SSE events.
118
+ We yield a final 'metadata' dictionary event once we're done.
119
+ """
120
+ step_counter : int = start_step_count
110
121
required_action : RequiredAction | None = None
111
- stream_manager : AsyncAssistantStreamManager = client .beta .threads .runs .stream (
112
- assistant_id = assistant_id ,
113
- thread_id = thread_id
114
- )
122
+ run_requires_action_event : ThreadRunRequiresAction | None = None
115
123
116
124
async with stream_manager as event_handler :
117
125
async for event in event_handler :
118
126
logger .info (f"{ event } " )
119
-
127
+
120
128
if isinstance (event , ThreadMessageCreated ):
121
129
step_counter += 1
122
130
@@ -127,7 +135,7 @@ async def event_generator():
127
135
stream_name = f"textDelta{ step_counter } "
128
136
)
129
137
)
130
- time .sleep (0.25 ) # Give the client time to render the message
138
+ time .sleep (0.25 ) # Give the client time to render the message
131
139
132
140
if isinstance (event , ThreadMessageDelta ):
133
141
logger .info (f"Sending delta with name textDelta{ step_counter } " )
@@ -136,56 +144,108 @@ async def event_generator():
136
144
event .data .delta .content [0 ].text .value
137
145
)
138
146
139
-
140
147
if isinstance (event , ThreadRunStepCreated ) and event .data .type == "tool_calls" :
141
148
yield sse_format (
142
149
f"toolCallCreated" ,
143
150
templates .get_template ('components/assistant-step.html' ).render (
144
- step_type = 'toolCall' , stream_name = f'toolDelta{ step_counter } '
151
+ step_type = 'toolCall' ,
152
+ stream_name = f'toolDelta{ step_counter } '
145
153
)
146
154
)
147
155
148
- if isinstance (event , ThreadRunStepDelta ) and event .data .type == "tool_calls" :
156
+ if isinstance (event , ThreadRunStepDelta ) and event .data .delta . step_details . type == "tool_calls" :
149
157
if event .data .delta .step_details .tool_calls [0 ].function .name :
150
158
yield sse_format (
151
159
f"toolDelta{ step_counter } " ,
152
- event .data .delta .step_details .tool_calls [0 ].function .name + "\n "
160
+ event .data .delta .step_details .tool_calls [0 ].function .name + "<br> "
153
161
)
154
162
elif event .data .delta .step_details .tool_calls [0 ].function .arguments :
155
163
yield sse_format (
156
164
f"toolDelta{ step_counter } " ,
157
165
event .data .delta .step_details .tool_calls [0 ].function .arguments
158
166
)
159
167
168
+ # If the assistant run requires an action (a tool call), break and handle it
160
169
if isinstance (event , ThreadRunRequiresAction ):
161
170
required_action = event .data .required_action
162
- if required_action and required_action . submit_tool_outputs :
163
- # Exit the for loop and context manager
171
+ run_requires_action_event = event
172
+ if required_action . submit_tool_outputs :
164
173
break
165
174
166
175
if isinstance (event , ThreadRunCompleted ):
167
176
yield sse_format ("endStream" , "DONE" )
168
-
169
- if required_action and required_action .submit_tool_outputs :
170
- # Get the weather
171
- for tool_call in required_action .submit_tool_outputs .tool_calls :
172
- try :
173
- args = json .loads (tool_call .function .arguments )
174
- location = args .get ("location" , "Unknown" )
175
- except Exception as err :
176
- logger .error (f"Failed to parse function arguments: { err } " )
177
- location = "Unknown"
178
-
179
- weather_output = get_weather (location )
180
- logger .info (f"Weather output: { weather_output } " )
181
-
182
- data_for_tool = {
183
- "tool_outputs" : weather_output ,
184
- "runId" : event .data .id ,
185
- }
186
- stream_manager : AsyncAssistantStreamManager = await post_tool_outputs (client , data_for_tool , thread_id )
187
-
188
- # We here need to run the whole stream management loop again
177
+
178
+ # At the end (or break) of this async generator, we yield a final "metadata" object
179
+ yield {
180
+ "type" : "metadata" ,
181
+ "required_action" : required_action ,
182
+ "step_counter" : step_counter ,
183
+ "run_requires_action_event" : run_requires_action_event
184
+ }
185
+
186
async def event_generator():
    """
    Main generator for SSE events sent to the client.

    Runs the assistant stream through ``handle_assistant_stream`` and forwards
    every SSE chunk it produces. The helper finishes each pass by yielding a
    ``{"type": "metadata", ...}`` dict; that dict is consumed here (never sent
    to the client) and tells us whether the run stopped because it requires a
    tool call. If it does, the tool is invoked, its output is submitted via
    ``post_tool_outputs``, and the stream returned by that call is processed
    in the next pass. The generator ends when a pass finishes without a
    pending tool call.
    """
    step_counter = 0
    # First run of the assistant stream
    stream_manager = client.beta.threads.runs.stream(
        assistant_id=assistant_id,
        thread_id=thread_id
    )

    # Each pass consumes one stream; we loop again only after submitting tool outputs
    while True:
        required_action: RequiredAction | None = None
        run_requires_action_event: ThreadRunRequiresAction | None = None

        async for event in handle_assistant_stream(templates, logger, stream_manager, step_counter):
            # Detect the special "metadata" event at the end of the helper generator
            if isinstance(event, dict) and event.get("type") == "metadata":
                required_action = event["required_action"]
                step_counter = event["step_counter"]
                run_requires_action_event = event["run_requires_action_event"]
                break
            # Normal SSE events: yield them to the client
            yield event

        # No tool call requested (or no metadata at all): we're done streaming.
        # Returning here also prevents re-entering an already-consumed stream.
        if not (required_action and required_action.submit_tool_outputs):
            return

        # Render outside the f-string: a backslash inside an f-string
        # replacement field is a SyntaxError before Python 3.12.
        tool_step_html = templates.get_template('components/assistant-step.html').render(
            step_type='toolCall', stream_name=f'toolDelta{step_counter}'
        ).replace('\n', '')

        data_for_tool: dict[str, Any] | None = None
        for tool_call in required_action.submit_tool_outputs.tool_calls:
            yield f"event: toolCallCreated\ndata: {tool_step_html}\n\n"

            if tool_call.type == "function" and tool_call.function.name == "get_weather":
                try:
                    args = json.loads(tool_call.function.arguments)
                    location = args.get("location", "Unknown")
                except Exception as err:
                    # Best effort: fall back to a placeholder location
                    logger.error(f"Failed to parse function arguments: {err}")
                    location = "Unknown"

                weather_output = get_weather(location)
                logger.info(f"Weather output: {weather_output}")

                if run_requires_action_event is None:
                    logger.error("Tool call requested but no requires_action event was captured")
                    return

                data_for_tool = {
                    "tool_outputs": weather_output,
                    # The run id lives on the requires_action event; `event`
                    # at this point is the metadata dict and has no `.data`.
                    "runId": run_requires_action_event.data.id,
                }

        if data_for_tool is None:
            # None of the requested tool calls is one we know how to serve
            logger.error("Assistant requested tool calls that are not supported; ending stream")
            return

        # Afterwards, create a fresh stream_manager for the next pass
        stream_manager = await post_tool_outputs(
            client,
            data_for_tool,
            thread_id
        )
189
249
190
250
return StreamingResponse (
191
251
event_generator (),
0 commit comments