1
1
import logging
2
2
import time
3
- from typing import Any
3
+ from typing import Any , AsyncGenerator
4
4
from fastapi .templating import Jinja2Templates
5
5
from fastapi import APIRouter , Form , Depends , Request
6
6
from fastapi .responses import StreamingResponse , HTMLResponse
7
7
from openai import AsyncOpenAI
8
8
from openai .resources .beta .threads .runs .runs import AsyncAssistantStreamManager
9
- from openai .types .beta .assistant_stream_event import ThreadMessageCreated , ThreadMessageDelta , ThreadRunCompleted , ThreadRunRequiresAction
9
+ from openai .types .beta .assistant_stream_event import (
10
+ ThreadMessageCreated , ThreadMessageDelta , ThreadRunCompleted ,
11
+ ThreadRunRequiresAction , ThreadRunStepCreated , ThreadRunStepDelta
12
+ )
13
+ from openai .types .beta .threads .run import RequiredAction
10
14
from fastapi .responses import StreamingResponse
11
15
from fastapi import APIRouter , Depends , Form , HTTPException
12
16
from pydantic import BaseModel
13
17
import json
14
18
15
- # Import our get_weather method
16
19
from utils .weather import get_weather
20
+ from utils .sse import sse_format
17
21
18
22
# Module-wide logger: bound to the logger named "uvicorn.error", with its
# level set to DEBUG so debug-level records from this module are not filtered.
logger : logging .Logger = logging .getLogger ("uvicorn.error" )
19
23
logger .setLevel (logging .DEBUG )
@@ -95,41 +99,113 @@ async def stream_response(
95
99
assistant_id : str ,
96
100
thread_id : str ,
97
101
client : AsyncOpenAI = Depends (lambda : AsyncOpenAI ())
98
- ) -> StreamingResponse :
99
-
100
- # Create a generator to stream the response from the assistant
101
- async def event_generator ():
102
- step_counter : int = 0
103
- stream_manager : AsyncAssistantStreamManager = client .beta .threads .runs .stream (
104
- assistant_id = assistant_id ,
105
- thread_id = thread_id
106
- )
102
+ ) -> StreamingResponse :
103
+ """
104
+ Streams the assistant response via Server-Sent Events (SSE). If the assistant requires
105
+ a tool call, we capture that action, invoke the tool, and then re-run the stream
106
+ until completion. This is done in a DRY way by extracting the streaming logic
107
+ into a helper function.
108
+ """
109
+
110
async def handle_assistant_stream(
    templates: Jinja2Templates,
    logger: logging.Logger,
    stream_manager: AsyncAssistantStreamManager,
    start_step_count: int = 0
) -> AsyncGenerator:
    """
    Relay the events of one assistant run as Server-Sent Events.

    Every item yielded is the result of ``sse_format(event_name, payload)``,
    except for the very last one, which is a plain ``dict`` with
    ``"type": "metadata"``. The caller uses that dict to learn whether the
    run stopped because it needs tool outputs.

    Args:
        templates: Template environment used to render
            ``components/assistant-step.html`` for new steps.
        logger: Logger that receives every raw stream event.
        stream_manager: Stream manager for the run; it is entered (and
            therefore closed) inside this generator.
        start_step_count: Step counter value to continue from, so that
            ``textDelta<N>`` / ``toolDelta<N>`` names stay unique across
            successive streams of the same response.

    Yields:
        SSE-formatted events, followed by one metadata dict with the keys
        ``type``, ``required_action``, ``step_counter`` and
        ``run_requires_action_event``.
    """
    # Imported locally so this function does not rely on a module-level
    # `import asyncio` being present.
    import asyncio

    step_counter: int = start_step_count
    required_action: RequiredAction | None = None
    run_requires_action_event: ThreadRunRequiresAction | None = None

    async with stream_manager as event_handler:
        async for event in event_handler:
            logger.info(f"{event}")

            if isinstance(event, ThreadMessageCreated):
                step_counter += 1

                yield sse_format(
                    "messageCreated",
                    templates.get_template("components/assistant-step.html").render(
                        step_type="assistantMessage",
                        stream_name=f"textDelta{step_counter}"
                    )
                )
                # Give the client time to render the message container before
                # deltas arrive. This must be an awaited sleep: time.sleep()
                # here would block the event loop and stall every other
                # request served by this worker for the whole pause.
                await asyncio.sleep(0.25)

            if isinstance(event, ThreadMessageDelta):
                logger.info(f"Sending delta with name textDelta{step_counter}")
                yield sse_format(
                    f"textDelta{step_counter}",
                    event.data.delta.content[0].text.value
                )

            if isinstance(event, ThreadRunStepCreated) and event.data.type == "tool_calls":
                yield sse_format(
                    "toolCallCreated",
                    templates.get_template("components/assistant-step.html").render(
                        step_type="toolCall",
                        stream_name=f"toolDelta{step_counter}"
                    )
                )

            if isinstance(event, ThreadRunStepDelta) and event.data.delta.step_details.type == "tool_calls":
                # Only the first tool call of the delta is forwarded: its name
                # when present, otherwise its argument fragment.
                function = event.data.delta.step_details.tool_calls[0].function
                if function.name:
                    yield sse_format(
                        f"toolDelta{step_counter}",
                        function.name + "<br>"
                    )
                elif function.arguments:
                    yield sse_format(
                        f"toolDelta{step_counter}",
                        function.arguments
                    )

            # If the assistant run requires an action (a tool call), stop
            # consuming this stream; the caller handles the action.
            if isinstance(event, ThreadRunRequiresAction):
                required_action = event.data.required_action
                run_requires_action_event = event
                if required_action.submit_tool_outputs:
                    break

            if isinstance(event, ThreadRunCompleted):
                yield sse_format("endStream", "DONE")

    # At the end (or break) of the stream, hand the collected state back to
    # the caller as a final "metadata" object.
    yield {
        "type": "metadata",
        "required_action": required_action,
        "step_counter": step_counter,
        "run_requires_action_event": run_requires_action_event
    }
185
+
186
+ async def event_generator ():
187
+ """
188
+ Main generator for SSE events. We call our helper function to handle the assistant
189
+ stream, and if the assistant requests a tool call, we do it and then re-run the stream.
190
+ """
191
+ step_counter = 0
192
+ # First run of the assistant stream
193
+ initial_manager = client .beta .threads .runs .stream (
194
+ assistant_id = assistant_id ,
195
+ thread_id = thread_id
196
+ )
197
+
198
+ # We'll re-run the loop if needed for tool calls
199
+ stream_manager = initial_manager
200
+ while True :
201
+ async for event in handle_assistant_stream (templates , logger , stream_manager , step_counter ):
202
+ # Detect the special "metadata" event at the end of the generator
203
+ if isinstance (event , dict ) and event .get ("type" ) == "metadata" :
204
+ required_action : RequiredAction | None = event ["required_action" ]
205
+ step_counter : int = event ["step_counter" ]
206
+ run_requires_action_event : ThreadRunRequiresAction | None = event ["run_requires_action_event" ]
207
+
208
+ # If the assistant still needs a tool call, do it and then re-stream
133
209
if required_action and required_action .submit_tool_outputs :
134
210
for tool_call in required_action .submit_tool_outputs .tool_calls :
135
211
yield (
@@ -154,14 +230,22 @@ async def event_generator():
154
230
"tool_outputs" : weather_output ,
155
231
"runId" : event .data .id ,
156
232
}
157
- await post_tool_outputs (client , data_for_tool , thread_id )
158
-
159
- if isinstance (event , ThreadRunCompleted ):
160
- yield "event: endStream\n data: DONE\n \n "
161
-
162
- # Send a done event when the stream is complete
163
- yield "event: endStream\n data: DONE\n \n "
164
-
233
+
234
+ # Afterwards, create a fresh stream_manager for the next iteration
235
+ new_stream_manager : AsyncAssistantStreamManager = await post_tool_outputs (
236
+ client ,
237
+ data_for_tool ,
238
+ thread_id
239
+ )
240
+ stream_manager = new_stream_manager
241
+ # proceed to rerun the loop
242
+ break
243
+ else :
244
+ # No more tool calls needed; we're done streaming
245
+ return
246
+ else :
247
+ # Normal SSE events: yield them to the client
248
+ yield event
165
249
166
250
return StreamingResponse (
167
251
event_generator (),
0 commit comments