|
import asyncio
import logging
import time
from typing import Any

from fastapi import APIRouter, Depends, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from openai import AsyncOpenAI
from openai.resources.beta.threads.runs.runs import AsyncAssistantStreamManager
from openai.types.beta.assistant_stream_event import (
    ThreadMessageCreated,
    ThreadMessageDelta,
    ThreadRunCompleted,
)
from pydantic import BaseModel
14 | 14 |
|
# Attach to uvicorn's error logger so messages appear alongside the server's
# own log output; DEBUG level makes the per-event stream logging visible.
logger: logging.Logger = logging.getLogger("uvicorn.error")
logger.setLevel(logging.DEBUG)
|
17 | 17 |
|
| 18 | + |
18 | 19 | router: APIRouter = APIRouter(
|
19 | 20 | prefix="/assistants/{assistant_id}/messages/{thread_id}",
|
20 | 21 | tags=["assistants_messages"]
|
@@ -47,33 +48,6 @@ async def post_tool_outputs(client: AsyncOpenAI, data: dict, thread_id: str):
|
47 | 48 | logger.error(f"Error submitting tool outputs: {e}")
|
48 | 49 | raise HTTPException(status_code=500, detail=str(e))
|
49 | 50 |
|
50 |
| -# TODO: Handle message created event by rendering assistant-step.html with the |
51 |
| -# event type ("assistantMessage" or "toolCall", in this case "assistantMessage") |
52 |
| -# and name ("assistantMessage" plus a counter) |
53 |
| - |
54 |
| -# Custom event handler for the assistant run stream |
55 |
| -class CustomEventHandler(AssistantEventHandler): |
56 |
| - def __init__(self): |
57 |
| - super().__init__() |
58 |
| - |
59 |
| - @override |
60 |
| - def on_tool_call_created(self, tool_call): |
61 |
| - yield f"<span class='tool-call'>Calling {tool_call.type} tool</span>\n" |
62 |
| - |
63 |
| - @override |
64 |
| - def on_tool_call_delta(self, delta, snapshot): |
65 |
| - if delta.type == 'code_interpreter': |
66 |
| - if delta.code_interpreter.input: |
67 |
| - yield f"<span class='code'>{delta.code_interpreter.input}</span>\n" |
68 |
| - if delta.code_interpreter.outputs: |
69 |
| - for output in delta.code_interpreter.outputs: |
70 |
| - if output.type == "logs": |
71 |
| - yield f"<span class='console'>{output.logs}</span>\n" |
72 |
| - if delta.type == "function": |
73 |
| - yield |
74 |
| - if delta.type == "file_search": |
75 |
| - yield |
76 |
| - |
77 | 51 |
|
78 | 52 | # Route to submit a new user message to a thread and mount a component that
|
79 | 53 | # will start an assistant run stream
|
@@ -114,22 +88,45 @@ async def stream_response(
|
114 | 88 | thread_id: str,
|
115 | 89 | client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
|
116 | 90 | ) -> StreamingResponse:
|
| 91 | + |
117 | 92 | # Create a generator to stream the response from the assistant
|
118 | 93 | async def event_generator():
|
| 94 | + step_counter: int = 0 |
119 | 95 | stream_manager: AsyncAssistantStreamManager = client.beta.threads.runs.stream(
|
120 | 96 | assistant_id=assistant_id,
|
121 | 97 | thread_id=thread_id
|
122 | 98 | )
|
123 | 99 |
|
124 |
| - event_handler = CustomEventHandler() |
125 |
| - |
126 | 100 | async with stream_manager as event_handler:
|
127 |
| - async for text in event_handler.text_deltas: |
128 |
| - yield f"data: {text.replace('\n', ' ')}\n\n" |
| 101 | + async for event in event_handler: |
| 102 | + logger.info(f"{event}") |
| 103 | + |
| 104 | + if isinstance(event, ThreadMessageCreated): |
| 105 | + step_counter += 1 |
| 106 | + |
| 107 | + yield ( |
| 108 | + f"event: messageCreated\n" |
| 109 | + f"data: {templates.get_template("components/assistant-step.html").render( |
| 110 | + step_type=f"assistantMessage", |
| 111 | + stream_name=f"textDelta{step_counter}" |
| 112 | + ).replace("\n", "")}\n\n" |
| 113 | + ) |
| 114 | + time.sleep(0.25) # Give the client time to render the message |
| 115 | + |
| 116 | + if isinstance(event, ThreadMessageDelta): |
| 117 | + logger.info(f"Sending delta with name textDelta{step_counter}") |
| 118 | + yield ( |
| 119 | + f"event: textDelta{step_counter}\n" |
| 120 | + f"data: {event.data.delta.content[0].text.value}\n\n" |
| 121 | + ) |
| 122 | + |
| 123 | + if isinstance(event, ThreadRunCompleted): |
| 124 | + yield "event: endStream\ndata: DONE\n\n" |
129 | 125 |
|
130 | 126 | # Send a done event when the stream is complete
|
131 |
| - yield "event: EndMessage\ndata: DONE\n\n" |
| 127 | + yield "event: endStream\ndata: DONE\n\n" |
132 | 128 |
|
| 129 | + |
133 | 130 | return StreamingResponse(
|
134 | 131 | event_generator(),
|
135 | 132 | media_type="text/event-stream",
|
|
0 commit comments