Skip to content

Commit 2dc83c7

Browse files
Very preliminary custome event handler
1 parent 53276d3 commit 2dc83c7

File tree

1 file changed

+40
-2
lines changed

1 file changed

+40
-2
lines changed

routers/messages.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22
from fastapi.templating import Jinja2Templates
33
from fastapi import APIRouter, Form, Depends, Request
44
from fastapi.responses import StreamingResponse
5-
from openai import AsyncOpenAI
5+
from openai import AsyncOpenAI, AssistantEventHandler
66
from openai.resources.beta.threads.runs.runs import AsyncAssistantStreamManager
7+
from typing_extensions import override
8+
from fastapi.responses import StreamingResponse
9+
from openai import AsyncOpenAI, AssistantEventHandler
10+
from fastapi import APIRouter, Depends, Form
11+
from typing_extensions import override
712

813
logger: logging.Logger = logging.getLogger("uvicorn.error")
914
logger.setLevel(logging.DEBUG)
1015

11-
# Initialize the router
16+
1217
router: APIRouter = APIRouter(
1318
prefix="/assistants/{assistant_id}/messages/{thread_id}",
1419
tags=["assistants_messages"]
@@ -17,6 +22,37 @@
1722
# Load Jinja2 templates
1823
templates = Jinja2Templates(directory="templates")
1924

25+
class CustomEventHandler(AssistantEventHandler):
26+
def __init__(self):
27+
super().__init__()
28+
self.message_content = ""
29+
30+
@override
31+
def on_text_created(self, text) -> None:
32+
print(f"\nassistant > ", end="", flush=True)
33+
34+
@override
35+
def on_text_delta(self, delta, snapshot):
36+
print(delta.value, end="", flush=True)
37+
self.message_content += delta.value
38+
39+
@override
40+
def on_text_done(self, text):
41+
print(f"\nassistant > done", flush=True)
42+
43+
def on_tool_call_created(self, tool_call):
44+
print(f"\nassistant > {tool_call.type}\n", flush=True)
45+
46+
def on_tool_call_delta(self, delta, snapshot):
47+
if delta.type == 'code_interpreter':
48+
if delta.code_interpreter.input:
49+
print(delta.code_interpreter.input, end="", flush=True)
50+
if delta.code_interpreter.outputs:
51+
print(f"\n\noutput >", flush=True)
52+
for output in delta.code_interpreter.outputs:
53+
if output.type == "logs":
54+
print(f"\n{output.logs}", flush=True)
55+
2056
# Send a new message to a thread
2157
@router.post("/send")
2258
async def post_message(
@@ -57,6 +93,8 @@ async def event_generator():
5793
thread_id=thread_id
5894
)
5995

96+
event_handler = CustomEventHandler()
97+
6098
async with stream_manager as event_handler:
6199
async for text in event_handler.text_deltas:
62100
yield f"data: {text.replace('\n', '<br>')}\n\n"

0 commit comments

Comments
 (0)