Skip to content

Commit 7971473

Browse files
Checkpoint
1 parent fa1eed5 commit 7971473

File tree

1 file changed

+40
-100
lines changed

1 file changed

+40
-100
lines changed

routers/chat.py

Lines changed: 40 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import time
3-
from typing import Any, AsyncGenerator
3+
from typing import Any
44
from fastapi.templating import Jinja2Templates
55
from fastapi import APIRouter, Form, Depends, Request
66
from fastapi.responses import StreamingResponse, HTMLResponse
@@ -16,6 +16,10 @@
1616
from pydantic import BaseModel
1717
import json
1818

19+
20+
21+
22+
# Import our get_weather method
1923
from utils.weather import get_weather
2024
from utils.sse import sse_format
2125

@@ -100,31 +104,19 @@ async def stream_response(
100104
thread_id: str,
101105
client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
102106
) -> StreamingResponse:
103-
"""
104-
Streams the assistant response via Server-Sent Events (SSE). If the assistant requires
105-
a tool call, we capture that action, invoke the tool, and then re-run the stream
106-
until completion. This is done in a DRY way by extracting the streaming logic
107-
into a helper function.
108-
"""
109107

110-
async def handle_assistant_stream(
111-
templates: Jinja2Templates,
112-
logger: logging.Logger,
113-
stream_manager: AsyncAssistantStreamManager,
114-
start_step_count: int = 0
115-
) -> AsyncGenerator:
116-
"""
117-
Async generator to yield SSE events.
118-
We yield a final 'metadata' dictionary event once we're done.
119-
"""
120-
step_counter: int = start_step_count
108+
async def event_generator():
109+
step_counter: int = 0
121110
required_action: RequiredAction | None = None
122-
run_requires_action_event: ThreadRunRequiresAction | None = None
111+
stream_manager: AsyncAssistantStreamManager = client.beta.threads.runs.stream(
112+
assistant_id=assistant_id,
113+
thread_id=thread_id
114+
)
123115

124116
async with stream_manager as event_handler:
125117
async for event in event_handler:
126118
logger.info(f"{event}")
127-
119+
128120
if isinstance(event, ThreadMessageCreated):
129121
step_counter += 1
130122

@@ -135,7 +127,7 @@ async def handle_assistant_stream(
135127
stream_name=f"textDelta{step_counter}"
136128
)
137129
)
138-
time.sleep(0.25) # Give the client time to render the message
130+
time.sleep(0.25) # Give the client time to render the message
139131

140132
if isinstance(event, ThreadMessageDelta):
141133
logger.info(f"Sending delta with name textDelta{step_counter}")
@@ -144,108 +136,56 @@ async def handle_assistant_stream(
144136
event.data.delta.content[0].text.value
145137
)
146138

139+
147140
if isinstance(event, ThreadRunStepCreated) and event.data.type == "tool_calls":
148141
yield sse_format(
149142
f"toolCallCreated",
150143
templates.get_template('components/assistant-step.html').render(
151-
step_type='toolCall',
152-
stream_name=f'toolDelta{step_counter}'
144+
step_type='toolCall', stream_name=f'toolDelta{step_counter}'
153145
)
154146
)
155147

156-
if isinstance(event, ThreadRunStepDelta) and event.data.delta.step_details.type == "tool_calls":
148+
if isinstance(event, ThreadRunStepDelta) and event.data.type == "tool_calls":
157149
if event.data.delta.step_details.tool_calls[0].function.name:
158150
yield sse_format(
159151
f"toolDelta{step_counter}",
160-
event.data.delta.step_details.tool_calls[0].function.name + "<br>"
152+
event.data.delta.step_details.tool_calls[0].function.name + "\n"
161153
)
162154
elif event.data.delta.step_details.tool_calls[0].function.arguments:
163155
yield sse_format(
164156
f"toolDelta{step_counter}",
165157
event.data.delta.step_details.tool_calls[0].function.arguments
166158
)
167159

168-
# If the assistant run requires an action (a tool call), break and handle it
169160
if isinstance(event, ThreadRunRequiresAction):
170161
required_action = event.data.required_action
171-
run_requires_action_event = event
172-
if required_action.submit_tool_outputs:
162+
if required_action and required_action.submit_tool_outputs:
163+
# Exit the for loop and context manager
173164
break
174165

175166
if isinstance(event, ThreadRunCompleted):
176167
yield sse_format("endStream", "DONE")
177-
178-
# At the end (or break) of this async generator, we yield a final "metadata" object
179-
yield {
180-
"type": "metadata",
181-
"required_action": required_action,
182-
"step_counter": step_counter,
183-
"run_requires_action_event": run_requires_action_event
184-
}
185-
186-
async def event_generator():
187-
"""
188-
Main generator for SSE events. We call our helper function to handle the assistant
189-
stream, and if the assistant requests a tool call, we do it and then re-run the stream.
190-
"""
191-
step_counter = 0
192-
# First run of the assistant stream
193-
initial_manager = client.beta.threads.runs.stream(
194-
assistant_id=assistant_id,
195-
thread_id=thread_id
196-
)
197-
198-
# We'll re-run the loop if needed for tool calls
199-
stream_manager = initial_manager
200-
while True:
201-
async for event in handle_assistant_stream(templates, logger, stream_manager, step_counter):
202-
# Detect the special "metadata" event at the end of the generator
203-
if isinstance(event, dict) and event.get("type") == "metadata":
204-
required_action: RequiredAction | None = event["required_action"]
205-
step_counter: int = event["step_counter"]
206-
run_requires_action_event: ThreadRunRequiresAction | None = event["run_requires_action_event"]
207-
208-
# If the assistant still needs a tool call, do it and then re-stream
209-
if required_action and required_action.submit_tool_outputs:
210-
for tool_call in required_action.submit_tool_outputs.tool_calls:
211-
yield (
212-
f"event: toolCallCreated\n"
213-
f"data: {templates.get_template('components/assistant-step.html').render(
214-
step_type='toolCall', stream_name=f'toolDelta{step_counter}'
215-
).replace('\n', '')}\n\n"
216-
)
217-
218-
if tool_call.type == "function" and tool_call.function.name == "get_weather":
219-
try:
220-
args = json.loads(tool_call.function.arguments)
221-
location = args.get("location", "Unknown")
222-
except Exception as err:
223-
logger.error(f"Failed to parse function arguments: {err}")
224-
location = "Unknown"
225-
226-
weather_output = get_weather(location)
227-
logger.info(f"Weather output: {weather_output}")
228-
229-
data_for_tool = {
230-
"tool_outputs": weather_output,
231-
"runId": event.data.id,
232-
}
233-
234-
# Afterwards, create a fresh stream_manager for the next iteration
235-
new_stream_manager: AsyncAssistantStreamManager = await post_tool_outputs(
236-
client,
237-
data_for_tool,
238-
thread_id
239-
)
240-
stream_manager = new_stream_manager
241-
# proceed to rerun the loop
242-
break
243-
else:
244-
# No more tool calls needed; we're done streaming
245-
return
246-
else:
247-
# Normal SSE events: yield them to the client
248-
yield event
168+
169+
if required_action and required_action.submit_tool_outputs:
170+
# Get the weather
171+
for tool_call in required_action.submit_tool_outputs.tool_calls:
172+
try:
173+
args = json.loads(tool_call.function.arguments)
174+
location = args.get("location", "Unknown")
175+
except Exception as err:
176+
logger.error(f"Failed to parse function arguments: {err}")
177+
location = "Unknown"
178+
179+
weather_output = get_weather(location)
180+
logger.info(f"Weather output: {weather_output}")
181+
182+
data_for_tool = {
183+
"tool_outputs": weather_output,
184+
"runId": event.data.id,
185+
}
186+
stream_manager: AsyncAssistantStreamManager = await post_tool_outputs(client, data_for_tool, thread_id)
187+
188+
# We here need to run the whole stream management loop again
249189

250190
return StreamingResponse(
251191
event_generator(),

0 commit comments

Comments
 (0)