1
1
import logging
2
2
import time
3
- from typing import Any , AsyncGenerator
3
+ from typing import Any
4
4
from fastapi .templating import Jinja2Templates
5
5
from fastapi import APIRouter , Form , Depends , Request
6
6
from fastapi .responses import StreamingResponse , HTMLResponse
16
16
from pydantic import BaseModel
17
17
import json
18
18
19
+
20
+
21
+
22
+ # Import our get_weather method
19
23
from utils .weather import get_weather
20
24
from utils .sse import sse_format
21
25
@@ -100,31 +104,19 @@ async def stream_response(
100
104
thread_id : str ,
101
105
client : AsyncOpenAI = Depends (lambda : AsyncOpenAI ())
102
106
) -> StreamingResponse :
103
- """
104
- Streams the assistant response via Server-Sent Events (SSE). If the assistant requires
105
- a tool call, we capture that action, invoke the tool, and then re-run the stream
106
- until completion. This is done in a DRY way by extracting the streaming logic
107
- into a helper function.
108
- """
109
107
110
- async def handle_assistant_stream (
111
- templates : Jinja2Templates ,
112
- logger : logging .Logger ,
113
- stream_manager : AsyncAssistantStreamManager ,
114
- start_step_count : int = 0
115
- ) -> AsyncGenerator :
116
- """
117
- Async generator to yield SSE events.
118
- We yield a final 'metadata' dictionary event once we're done.
119
- """
120
- step_counter : int = start_step_count
108
+ async def event_generator ():
109
+ step_counter : int = 0
121
110
required_action : RequiredAction | None = None
122
- run_requires_action_event : ThreadRunRequiresAction | None = None
111
+ stream_manager : AsyncAssistantStreamManager = client .beta .threads .runs .stream (
112
+ assistant_id = assistant_id ,
113
+ thread_id = thread_id
114
+ )
123
115
124
116
async with stream_manager as event_handler :
125
117
async for event in event_handler :
126
118
logger .info (f"{ event } " )
127
-
119
+
128
120
if isinstance (event , ThreadMessageCreated ):
129
121
step_counter += 1
130
122
@@ -135,7 +127,7 @@ async def handle_assistant_stream(
135
127
stream_name = f"textDelta{ step_counter } "
136
128
)
137
129
)
138
- time .sleep (0.25 ) # Give the client time to render the message
130
+ time .sleep (0.25 ) # Give the client time to render the message
139
131
140
132
if isinstance (event , ThreadMessageDelta ):
141
133
logger .info (f"Sending delta with name textDelta{ step_counter } " )
@@ -144,108 +136,56 @@ async def handle_assistant_stream(
144
136
event .data .delta .content [0 ].text .value
145
137
)
146
138
139
+
147
140
if isinstance (event , ThreadRunStepCreated ) and event .data .type == "tool_calls" :
148
141
yield sse_format (
149
142
f"toolCallCreated" ,
150
143
templates .get_template ('components/assistant-step.html' ).render (
151
- step_type = 'toolCall' ,
152
- stream_name = f'toolDelta{ step_counter } '
144
+ step_type = 'toolCall' , stream_name = f'toolDelta{ step_counter } '
153
145
)
154
146
)
155
147
156
- if isinstance (event , ThreadRunStepDelta ) and event .data .delta . step_details . type == "tool_calls" :
148
+ if isinstance (event , ThreadRunStepDelta ) and event .data .type == "tool_calls" :
157
149
if event .data .delta .step_details .tool_calls [0 ].function .name :
158
150
yield sse_format (
159
151
f"toolDelta{ step_counter } " ,
160
- event .data .delta .step_details .tool_calls [0 ].function .name + "<br> "
152
+ event .data .delta .step_details .tool_calls [0 ].function .name + "\n "
161
153
)
162
154
elif event .data .delta .step_details .tool_calls [0 ].function .arguments :
163
155
yield sse_format (
164
156
f"toolDelta{ step_counter } " ,
165
157
event .data .delta .step_details .tool_calls [0 ].function .arguments
166
158
)
167
159
168
- # If the assistant run requires an action (a tool call), break and handle it
169
160
if isinstance (event , ThreadRunRequiresAction ):
170
161
required_action = event .data .required_action
171
- run_requires_action_event = event
172
- if required_action . submit_tool_outputs :
162
+ if required_action and required_action . submit_tool_outputs :
163
+ # Exit the for loop and context manager
173
164
break
174
165
175
166
if isinstance (event , ThreadRunCompleted ):
176
167
yield sse_format ("endStream" , "DONE" )
177
-
178
- # At the end (or break) of this async generator, we yield a final "metadata" object
179
- yield {
180
- "type" : "metadata" ,
181
- "required_action" : required_action ,
182
- "step_counter" : step_counter ,
183
- "run_requires_action_event" : run_requires_action_event
184
- }
185
-
186
- async def event_generator ():
187
- """
188
- Main generator for SSE events. We call our helper function to handle the assistant
189
- stream, and if the assistant requests a tool call, we do it and then re-run the stream.
190
- """
191
- step_counter = 0
192
- # First run of the assistant stream
193
- initial_manager = client .beta .threads .runs .stream (
194
- assistant_id = assistant_id ,
195
- thread_id = thread_id
196
- )
197
-
198
- # We'll re-run the loop if needed for tool calls
199
- stream_manager = initial_manager
200
- while True :
201
- async for event in handle_assistant_stream (templates , logger , stream_manager , step_counter ):
202
- # Detect the special "metadata" event at the end of the generator
203
- if isinstance (event , dict ) and event .get ("type" ) == "metadata" :
204
- required_action : RequiredAction | None = event ["required_action" ]
205
- step_counter : int = event ["step_counter" ]
206
- run_requires_action_event : ThreadRunRequiresAction | None = event ["run_requires_action_event" ]
207
-
208
- # If the assistant still needs a tool call, do it and then re-stream
209
- if required_action and required_action .submit_tool_outputs :
210
- for tool_call in required_action .submit_tool_outputs .tool_calls :
211
- yield (
212
- f"event: toolCallCreated\n "
213
- f"data: { templates .get_template ('components/assistant-step.html' ).render (
214
- step_type = 'toolCall' , stream_name = f'toolDelta{ step_counter } '
215
- ).replace ('\n ' , '' )} \n \n "
216
- )
217
-
218
- if tool_call .type == "function" and tool_call .function .name == "get_weather" :
219
- try :
220
- args = json .loads (tool_call .function .arguments )
221
- location = args .get ("location" , "Unknown" )
222
- except Exception as err :
223
- logger .error (f"Failed to parse function arguments: { err } " )
224
- location = "Unknown"
225
-
226
- weather_output = get_weather (location )
227
- logger .info (f"Weather output: { weather_output } " )
228
-
229
- data_for_tool = {
230
- "tool_outputs" : weather_output ,
231
- "runId" : event .data .id ,
232
- }
233
-
234
- # Afterwards, create a fresh stream_manager for the next iteration
235
- new_stream_manager : AsyncAssistantStreamManager = await post_tool_outputs (
236
- client ,
237
- data_for_tool ,
238
- thread_id
239
- )
240
- stream_manager = new_stream_manager
241
- # proceed to rerun the loop
242
- break
243
- else :
244
- # No more tool calls needed; we're done streaming
245
- return
246
- else :
247
- # Normal SSE events: yield them to the client
248
- yield event
168
+
169
+ if required_action and required_action .submit_tool_outputs :
170
+ # Get the weather
171
+ for tool_call in required_action .submit_tool_outputs .tool_calls :
172
+ try :
173
+ args = json .loads (tool_call .function .arguments )
174
+ location = args .get ("location" , "Unknown" )
175
+ except Exception as err :
176
+ logger .error (f"Failed to parse function arguments: { err } " )
177
+ location = "Unknown"
178
+
179
+ weather_output = get_weather (location )
180
+ logger .info (f"Weather output: { weather_output } " )
181
+
182
+ data_for_tool = {
183
+ "tool_outputs" : weather_output ,
184
+ "runId" : event .data .id ,
185
+ }
186
+ stream_manager : AsyncAssistantStreamManager = await post_tool_outputs (client , data_for_tool , thread_id )
187
+
188
+ # We here need to run the whole stream management loop again
249
189
250
190
return StreamingResponse (
251
191
event_generator (),
0 commit comments