Skip to content

Commit c154822

Browse files
Iterate over events rather than text deltas
1 parent fd4d64b commit c154822

File tree

4 files changed

+37
-43
lines changed

4 files changed

+37
-43
lines changed

routers/chat.py

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
import logging
2+
import time
3+
from typing import Any
24
from fastapi.templating import Jinja2Templates
35
from fastapi import APIRouter, Form, Depends, Request
46
from fastapi.responses import StreamingResponse, HTMLResponse
5-
from openai import AsyncOpenAI, AssistantEventHandler
7+
from openai import AsyncOpenAI
68
from openai.resources.beta.threads.runs.runs import AsyncAssistantStreamManager
7-
from openai.types.beta.threads.runs import RunStep, RunStepDelta
8-
from typing_extensions import override
9+
from openai.types.beta.assistant_stream_event import ThreadMessageCreated, ThreadMessageDelta, ThreadRunCompleted
910
from fastapi.responses import StreamingResponse
10-
from openai import AsyncOpenAI, AssistantEventHandler
1111
from fastapi import APIRouter, Depends, Form, HTTPException
1212
from pydantic import BaseModel
13-
from typing import Any
13+
1414

1515
logger: logging.Logger = logging.getLogger("uvicorn.error")
1616
logger.setLevel(logging.DEBUG)
1717

18+
1819
router: APIRouter = APIRouter(
1920
prefix="/assistants/{assistant_id}/messages/{thread_id}",
2021
tags=["assistants_messages"]
@@ -47,33 +48,6 @@ async def post_tool_outputs(client: AsyncOpenAI, data: dict, thread_id: str):
4748
logger.error(f"Error submitting tool outputs: {e}")
4849
raise HTTPException(status_code=500, detail=str(e))
4950

50-
# TODO: Handle message created event by rendering assistant-step.html with the
51-
# event type ("assistantMessage" or "toolCall", in this case "assistantMessage")
52-
# and name ("assistantMessage" plus a counter)
53-
54-
# Custom event handler for the assistant run stream
55-
class CustomEventHandler(AssistantEventHandler):
56-
def __init__(self):
57-
super().__init__()
58-
59-
@override
60-
def on_tool_call_created(self, tool_call):
61-
yield f"<span class='tool-call'>Calling {tool_call.type} tool</span>\n"
62-
63-
@override
64-
def on_tool_call_delta(self, delta, snapshot):
65-
if delta.type == 'code_interpreter':
66-
if delta.code_interpreter.input:
67-
yield f"<span class='code'>{delta.code_interpreter.input}</span>\n"
68-
if delta.code_interpreter.outputs:
69-
for output in delta.code_interpreter.outputs:
70-
if output.type == "logs":
71-
yield f"<span class='console'>{output.logs}</span>\n"
72-
if delta.type == "function":
73-
yield
74-
if delta.type == "file_search":
75-
yield
76-
7751

7852
# Route to submit a new user message to a thread and mount a component that
7953
# will start an assistant run stream
@@ -114,22 +88,45 @@ async def stream_response(
11488
thread_id: str,
11589
client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
11690
) -> StreamingResponse:
91+
11792
# Create a generator to stream the response from the assistant
11893
async def event_generator():
94+
step_counter: int = 0
11995
stream_manager: AsyncAssistantStreamManager = client.beta.threads.runs.stream(
12096
assistant_id=assistant_id,
12197
thread_id=thread_id
12298
)
12399

124-
event_handler = CustomEventHandler()
125-
126100
async with stream_manager as event_handler:
127-
async for text in event_handler.text_deltas:
128-
yield f"data: {text.replace('\n', '&#10;')}\n\n"
101+
async for event in event_handler:
102+
logger.info(f"{event}")
103+
104+
if isinstance(event, ThreadMessageCreated):
105+
step_counter += 1
106+
107+
yield (
108+
f"event: messageCreated\n"
109+
f"data: {templates.get_template("components/assistant-step.html").render(
110+
step_type=f"assistantMessage",
111+
stream_name=f"textDelta{step_counter}"
112+
).replace("\n", "")}\n\n"
113+
)
114+
time.sleep(0.25) # Give the client time to render the message
115+
116+
if isinstance(event, ThreadMessageDelta):
117+
logger.info(f"Sending delta with name textDelta{step_counter}")
118+
yield (
119+
f"event: textDelta{step_counter}\n"
120+
f"data: {event.data.delta.content[0].text.value}\n\n"
121+
)
122+
123+
if isinstance(event, ThreadRunCompleted):
124+
yield "event: endStream\ndata: DONE\n\n"
129125

130126
# Send a done event when the stream is complete
131-
yield "event: EndMessage\ndata: DONE\n\n"
127+
yield "event: endStream\ndata: DONE\n\n"
132128

129+
133130
return StreamingResponse(
134131
event_generator(),
135132
media_type="text/event-stream",

templates/components/assistant-run.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
<!-- assistant-run.html -->
22
<div class="assistant-run" hx-swap="beforeend"
3+
hx-ext="sse"
34
sse-connect="/assistants/{{ assistant_id }}/messages/{{ thread_id }}/receive"
45
sse-swap="messageCreated"
56
sse-close="endStream">
Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
11
<!-- assistant-step.html -->
2-
<div class="{{ event_type }}"
3-
sse-swap="{{ event_name }}"
4-
hx-swap="beforeend">
2+
<div class="{{ step_type }}" sse-swap="{{ stream_name }}">
53
</div>
Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
11
<!-- user-message.html -->
2-
<div class="userMessage">
3-
{{ user_input }}
4-
</div>
2+
<div class="userMessage">{{ user_input }}</div>

0 commit comments

Comments
 (0)