Skip to content

Commit fa1eed5

Browse files
Strem the tool calls as a code message
1 parent fb0b274 commit fa1eed5

File tree

3 files changed

+135
-32
lines changed

3 files changed

+135
-32
lines changed

routers/chat.py

Lines changed: 114 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
import logging
22
import time
3-
from typing import Any
3+
from typing import Any, AsyncGenerator
44
from fastapi.templating import Jinja2Templates
55
from fastapi import APIRouter, Form, Depends, Request
66
from fastapi.responses import StreamingResponse, HTMLResponse
77
from openai import AsyncOpenAI
88
from openai.resources.beta.threads.runs.runs import AsyncAssistantStreamManager
9-
from openai.types.beta.assistant_stream_event import ThreadMessageCreated, ThreadMessageDelta, ThreadRunCompleted, ThreadRunRequiresAction
9+
from openai.types.beta.assistant_stream_event import (
10+
ThreadMessageCreated, ThreadMessageDelta, ThreadRunCompleted,
11+
ThreadRunRequiresAction, ThreadRunStepCreated, ThreadRunStepDelta
12+
)
13+
from openai.types.beta.threads.run import RequiredAction
1014
from fastapi.responses import StreamingResponse
1115
from fastapi import APIRouter, Depends, Form, HTTPException
1216
from pydantic import BaseModel
1317
import json
1418

15-
# Import our get_weather method
1619
from utils.weather import get_weather
20+
from utils.sse import sse_format
1721

1822
logger: logging.Logger = logging.getLogger("uvicorn.error")
1923
logger.setLevel(logging.DEBUG)
@@ -95,41 +99,113 @@ async def stream_response(
9599
assistant_id: str,
96100
thread_id: str,
97101
client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
98-
) -> StreamingResponse:
99-
100-
# Create a generator to stream the response from the assistant
101-
async def event_generator():
102-
step_counter: int = 0
103-
stream_manager: AsyncAssistantStreamManager = client.beta.threads.runs.stream(
104-
assistant_id=assistant_id,
105-
thread_id=thread_id
106-
)
102+
) -> StreamingResponse:
103+
"""
104+
Streams the assistant response via Server-Sent Events (SSE). If the assistant requires
105+
a tool call, we capture that action, invoke the tool, and then re-run the stream
106+
until completion. This is done in a DRY way by extracting the streaming logic
107+
into a helper function.
108+
"""
109+
110+
async def handle_assistant_stream(
111+
templates: Jinja2Templates,
112+
logger: logging.Logger,
113+
stream_manager: AsyncAssistantStreamManager,
114+
start_step_count: int = 0
115+
) -> AsyncGenerator:
116+
"""
117+
Async generator to yield SSE events.
118+
We yield a final 'metadata' dictionary event once we're done.
119+
"""
120+
step_counter: int = start_step_count
121+
required_action: RequiredAction | None = None
122+
run_requires_action_event: ThreadRunRequiresAction | None = None
107123

108124
async with stream_manager as event_handler:
109125
async for event in event_handler:
110126
logger.info(f"{event}")
111-
127+
112128
if isinstance(event, ThreadMessageCreated):
113129
step_counter += 1
114130

115-
yield (
116-
f"event: messageCreated\n"
117-
f"data: {templates.get_template("components/assistant-step.html").render(
118-
step_type=f"assistantMessage",
131+
yield sse_format(
132+
"messageCreated",
133+
templates.get_template("components/assistant-step.html").render(
134+
step_type="assistantMessage",
119135
stream_name=f"textDelta{step_counter}"
120-
).replace("\n", "")}\n\n"
136+
)
121137
)
122-
time.sleep(0.25) # Give the client time to render the message
138+
time.sleep(0.25) # Give the client time to render the message
123139

124140
if isinstance(event, ThreadMessageDelta):
125141
logger.info(f"Sending delta with name textDelta{step_counter}")
126-
yield (
127-
f"event: textDelta{step_counter}\n"
128-
f"data: {event.data.delta.content[0].text.value}\n\n"
142+
yield sse_format(
143+
f"textDelta{step_counter}",
144+
event.data.delta.content[0].text.value
145+
)
146+
147+
if isinstance(event, ThreadRunStepCreated) and event.data.type == "tool_calls":
148+
yield sse_format(
149+
f"toolCallCreated",
150+
templates.get_template('components/assistant-step.html').render(
151+
step_type='toolCall',
152+
stream_name=f'toolDelta{step_counter}'
153+
)
129154
)
130155

156+
if isinstance(event, ThreadRunStepDelta) and event.data.delta.step_details.type == "tool_calls":
157+
if event.data.delta.step_details.tool_calls[0].function.name:
158+
yield sse_format(
159+
f"toolDelta{step_counter}",
160+
event.data.delta.step_details.tool_calls[0].function.name + "<br>"
161+
)
162+
elif event.data.delta.step_details.tool_calls[0].function.arguments:
163+
yield sse_format(
164+
f"toolDelta{step_counter}",
165+
event.data.delta.step_details.tool_calls[0].function.arguments
166+
)
167+
168+
# If the assistant run requires an action (a tool call), break and handle it
131169
if isinstance(event, ThreadRunRequiresAction):
132170
required_action = event.data.required_action
171+
run_requires_action_event = event
172+
if required_action.submit_tool_outputs:
173+
break
174+
175+
if isinstance(event, ThreadRunCompleted):
176+
yield sse_format("endStream", "DONE")
177+
178+
# At the end (or break) of this async generator, we yield a final "metadata" object
179+
yield {
180+
"type": "metadata",
181+
"required_action": required_action,
182+
"step_counter": step_counter,
183+
"run_requires_action_event": run_requires_action_event
184+
}
185+
186+
async def event_generator():
187+
"""
188+
Main generator for SSE events. We call our helper function to handle the assistant
189+
stream, and if the assistant requests a tool call, we do it and then re-run the stream.
190+
"""
191+
step_counter = 0
192+
# First run of the assistant stream
193+
initial_manager = client.beta.threads.runs.stream(
194+
assistant_id=assistant_id,
195+
thread_id=thread_id
196+
)
197+
198+
# We'll re-run the loop if needed for tool calls
199+
stream_manager = initial_manager
200+
while True:
201+
async for event in handle_assistant_stream(templates, logger, stream_manager, step_counter):
202+
# Detect the special "metadata" event at the end of the generator
203+
if isinstance(event, dict) and event.get("type") == "metadata":
204+
required_action: RequiredAction | None = event["required_action"]
205+
step_counter: int = event["step_counter"]
206+
run_requires_action_event: ThreadRunRequiresAction | None = event["run_requires_action_event"]
207+
208+
# If the assistant still needs a tool call, do it and then re-stream
133209
if required_action and required_action.submit_tool_outputs:
134210
for tool_call in required_action.submit_tool_outputs.tool_calls:
135211
yield (
@@ -154,14 +230,22 @@ async def event_generator():
154230
"tool_outputs": weather_output,
155231
"runId": event.data.id,
156232
}
157-
await post_tool_outputs(client, data_for_tool, thread_id)
158-
159-
if isinstance(event, ThreadRunCompleted):
160-
yield "event: endStream\ndata: DONE\n\n"
161-
162-
# Send a done event when the stream is complete
163-
yield "event: endStream\ndata: DONE\n\n"
164-
233+
234+
# Afterwards, create a fresh stream_manager for the next iteration
235+
new_stream_manager: AsyncAssistantStreamManager = await post_tool_outputs(
236+
client,
237+
data_for_tool,
238+
thread_id
239+
)
240+
stream_manager = new_stream_manager
241+
# proceed to rerun the loop
242+
break
243+
else:
244+
# No more tool calls needed; we're done streaming
245+
return
246+
else:
247+
# Normal SSE events: yield them to the client
248+
yield event
165249

166250
return StreamingResponse(
167251
event_generator(),
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
<!-- assistant-step.html -->
2-
<div class="{{ step_type }}" sse-swap="{{ stream_name }}">
3-
</div>
2+
<div class="{{ step_type }}" sse-swap="{{ stream_name }}"></div>

utils/sse.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
def sse_format(event: str, data: str, retry: int = None) -> str:
2+
"""
3+
Helper function to format a Server-Sent Event (SSE) message.
4+
5+
Args:
6+
event: The name/type of the event.
7+
data: The data payload as a string.
8+
retry: Optional retry timeout in milliseconds.
9+
10+
Returns:
11+
A formatted SSE message string.
12+
"""
13+
output = f"event: {event}\n"
14+
if retry is not None:
15+
output += f"retry: {retry}\n"
16+
# Ensure each line of data is prefixed with "data: "
17+
for line in data.splitlines():
18+
output += f"data: {line}\n"
19+
output += "\n" # An extra newline indicates the end of the message.
20+
return output

0 commit comments

Comments
 (0)