Skip to content

Commit 67fe090

Browse files
Revert "Checkpoint"
This reverts commit 7971473.
1 parent 7971473 commit 67fe090

File tree

1 file changed

+100
-40
lines changed

1 file changed

+100
-40
lines changed

routers/chat.py

Lines changed: 100 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import time
3-
from typing import Any
3+
from typing import Any, AsyncGenerator
44
from fastapi.templating import Jinja2Templates
55
from fastapi import APIRouter, Form, Depends, Request
66
from fastapi.responses import StreamingResponse, HTMLResponse
@@ -16,10 +16,6 @@
1616
from pydantic import BaseModel
1717
import json
1818

19-
20-
21-
22-
# Import our get_weather method
2319
from utils.weather import get_weather
2420
from utils.sse import sse_format
2521

@@ -104,19 +100,31 @@ async def stream_response(
104100
thread_id: str,
105101
client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
106102
) -> StreamingResponse:
103+
"""
104+
Streams the assistant response via Server-Sent Events (SSE). If the assistant requires
105+
a tool call, we capture that action, invoke the tool, and then re-run the stream
106+
until completion. This is done in a DRY way by extracting the streaming logic
107+
into a helper function.
108+
"""
107109

108-
async def event_generator():
109-
step_counter: int = 0
110+
async def handle_assistant_stream(
111+
templates: Jinja2Templates,
112+
logger: logging.Logger,
113+
stream_manager: AsyncAssistantStreamManager,
114+
start_step_count: int = 0
115+
) -> AsyncGenerator:
116+
"""
117+
Async generator to yield SSE events.
118+
We yield a final 'metadata' dictionary event once we're done.
119+
"""
120+
step_counter: int = start_step_count
110121
required_action: RequiredAction | None = None
111-
stream_manager: AsyncAssistantStreamManager = client.beta.threads.runs.stream(
112-
assistant_id=assistant_id,
113-
thread_id=thread_id
114-
)
122+
run_requires_action_event: ThreadRunRequiresAction | None = None
115123

116124
async with stream_manager as event_handler:
117125
async for event in event_handler:
118126
logger.info(f"{event}")
119-
127+
120128
if isinstance(event, ThreadMessageCreated):
121129
step_counter += 1
122130

@@ -127,7 +135,7 @@ async def event_generator():
127135
stream_name=f"textDelta{step_counter}"
128136
)
129137
)
130-
time.sleep(0.25) # Give the client time to render the message
138+
time.sleep(0.25) # Give the client time to render the message
131139

132140
if isinstance(event, ThreadMessageDelta):
133141
logger.info(f"Sending delta with name textDelta{step_counter}")
@@ -136,56 +144,108 @@ async def event_generator():
136144
event.data.delta.content[0].text.value
137145
)
138146

139-
140147
if isinstance(event, ThreadRunStepCreated) and event.data.type == "tool_calls":
141148
yield sse_format(
142149
f"toolCallCreated",
143150
templates.get_template('components/assistant-step.html').render(
144-
step_type='toolCall', stream_name=f'toolDelta{step_counter}'
151+
step_type='toolCall',
152+
stream_name=f'toolDelta{step_counter}'
145153
)
146154
)
147155

148-
if isinstance(event, ThreadRunStepDelta) and event.data.type == "tool_calls":
156+
if isinstance(event, ThreadRunStepDelta) and event.data.delta.step_details.type == "tool_calls":
149157
if event.data.delta.step_details.tool_calls[0].function.name:
150158
yield sse_format(
151159
f"toolDelta{step_counter}",
152-
event.data.delta.step_details.tool_calls[0].function.name + "\n"
160+
event.data.delta.step_details.tool_calls[0].function.name + "<br>"
153161
)
154162
elif event.data.delta.step_details.tool_calls[0].function.arguments:
155163
yield sse_format(
156164
f"toolDelta{step_counter}",
157165
event.data.delta.step_details.tool_calls[0].function.arguments
158166
)
159167

168+
# If the assistant run requires an action (a tool call), break and handle it
160169
if isinstance(event, ThreadRunRequiresAction):
161170
required_action = event.data.required_action
162-
if required_action and required_action.submit_tool_outputs:
163-
# Exit the for loop and context manager
171+
run_requires_action_event = event
172+
if required_action.submit_tool_outputs:
164173
break
165174

166175
if isinstance(event, ThreadRunCompleted):
167176
yield sse_format("endStream", "DONE")
168-
169-
if required_action and required_action.submit_tool_outputs:
170-
# Get the weather
171-
for tool_call in required_action.submit_tool_outputs.tool_calls:
172-
try:
173-
args = json.loads(tool_call.function.arguments)
174-
location = args.get("location", "Unknown")
175-
except Exception as err:
176-
logger.error(f"Failed to parse function arguments: {err}")
177-
location = "Unknown"
178-
179-
weather_output = get_weather(location)
180-
logger.info(f"Weather output: {weather_output}")
181-
182-
data_for_tool = {
183-
"tool_outputs": weather_output,
184-
"runId": event.data.id,
185-
}
186-
stream_manager: AsyncAssistantStreamManager = await post_tool_outputs(client, data_for_tool, thread_id)
187-
188-
# We here need to run the whole stream management loop again
177+
178+
# At the end (or break) of this async generator, we yield a final "metadata" object
179+
yield {
180+
"type": "metadata",
181+
"required_action": required_action,
182+
"step_counter": step_counter,
183+
"run_requires_action_event": run_requires_action_event
184+
}
185+
186+
async def event_generator():
187+
"""
188+
Main generator for SSE events. We call our helper function to handle the assistant
189+
stream, and if the assistant requests a tool call, we do it and then re-run the stream.
190+
"""
191+
step_counter = 0
192+
# First run of the assistant stream
193+
initial_manager = client.beta.threads.runs.stream(
194+
assistant_id=assistant_id,
195+
thread_id=thread_id
196+
)
197+
198+
# We'll re-run the loop if needed for tool calls
199+
stream_manager = initial_manager
200+
while True:
201+
async for event in handle_assistant_stream(templates, logger, stream_manager, step_counter):
202+
# Detect the special "metadata" event at the end of the generator
203+
if isinstance(event, dict) and event.get("type") == "metadata":
204+
required_action: RequiredAction | None = event["required_action"]
205+
step_counter: int = event["step_counter"]
206+
run_requires_action_event: ThreadRunRequiresAction | None = event["run_requires_action_event"]
207+
208+
# If the assistant still needs a tool call, do it and then re-stream
209+
if required_action and required_action.submit_tool_outputs:
210+
for tool_call in required_action.submit_tool_outputs.tool_calls:
211+
yield (
212+
f"event: toolCallCreated\n"
213+
f"data: {templates.get_template('components/assistant-step.html').render(
214+
step_type='toolCall', stream_name=f'toolDelta{step_counter}'
215+
).replace('\n', '')}\n\n"
216+
)
217+
218+
if tool_call.type == "function" and tool_call.function.name == "get_weather":
219+
try:
220+
args = json.loads(tool_call.function.arguments)
221+
location = args.get("location", "Unknown")
222+
except Exception as err:
223+
logger.error(f"Failed to parse function arguments: {err}")
224+
location = "Unknown"
225+
226+
weather_output = get_weather(location)
227+
logger.info(f"Weather output: {weather_output}")
228+
229+
data_for_tool = {
230+
"tool_outputs": weather_output,
231+
"runId": event.data.id,
232+
}
233+
234+
# Afterwards, create a fresh stream_manager for the next iteration
235+
new_stream_manager: AsyncAssistantStreamManager = await post_tool_outputs(
236+
client,
237+
data_for_tool,
238+
thread_id
239+
)
240+
stream_manager = new_stream_manager
241+
# proceed to rerun the loop
242+
break
243+
else:
244+
# No more tool calls needed; we're done streaming
245+
return
246+
else:
247+
# Normal SSE events: yield them to the client
248+
yield event
189249

190250
return StreamingResponse(
191251
event_generator(),

0 commit comments

Comments
 (0)