|
2 | 2 | from fastapi.templating import Jinja2Templates
|
3 | 3 | from fastapi import APIRouter, Form, Depends, Request
|
4 | 4 | from fastapi.responses import StreamingResponse
|
5 |
| -from openai import AsyncOpenAI |
| 5 | +from openai import AsyncOpenAI, AssistantEventHandler |
6 | 6 | from openai.resources.beta.threads.runs.runs import AsyncAssistantStreamManager
|
| 7 | +from typing_extensions import override |
| 8 | +from fastapi.responses import StreamingResponse |
| 9 | +from openai import AsyncOpenAI, AssistantEventHandler |
| 10 | +from fastapi import APIRouter, Depends, Form |
| 11 | +from typing_extensions import override |
7 | 12 |
|
8 | 13 | logger: logging.Logger = logging.getLogger("uvicorn.error")
|
9 | 14 | logger.setLevel(logging.DEBUG)
|
10 | 15 |
|
11 |
| -# Initialize the router |
| 16 | + |
12 | 17 | router: APIRouter = APIRouter(
|
13 | 18 | prefix="/assistants/{assistant_id}/messages/{thread_id}",
|
14 | 19 | tags=["assistants_messages"]
|
|
17 | 22 | # Load Jinja2 templates
|
18 | 23 | templates = Jinja2Templates(directory="templates")
|
19 | 24 |
|
| 25 | +class CustomEventHandler(AssistantEventHandler): |
| 26 | + def __init__(self): |
| 27 | + super().__init__() |
| 28 | + self.message_content = "" |
| 29 | + |
| 30 | + @override |
| 31 | + def on_text_created(self, text) -> None: |
| 32 | + print(f"\nassistant > ", end="", flush=True) |
| 33 | + |
| 34 | + @override |
| 35 | + def on_text_delta(self, delta, snapshot): |
| 36 | + print(delta.value, end="", flush=True) |
| 37 | + self.message_content += delta.value |
| 38 | + |
| 39 | + @override |
| 40 | + def on_text_done(self, text): |
| 41 | + print(f"\nassistant > done", flush=True) |
| 42 | + |
| 43 | + def on_tool_call_created(self, tool_call): |
| 44 | + print(f"\nassistant > {tool_call.type}\n", flush=True) |
| 45 | + |
| 46 | + def on_tool_call_delta(self, delta, snapshot): |
| 47 | + if delta.type == 'code_interpreter': |
| 48 | + if delta.code_interpreter.input: |
| 49 | + print(delta.code_interpreter.input, end="", flush=True) |
| 50 | + if delta.code_interpreter.outputs: |
| 51 | + print(f"\n\noutput >", flush=True) |
| 52 | + for output in delta.code_interpreter.outputs: |
| 53 | + if output.type == "logs": |
| 54 | + print(f"\n{output.logs}", flush=True) |
| 55 | + |
20 | 56 | # Send a new message to a thread
|
21 | 57 | @router.post("/send")
|
22 | 58 | async def post_message(
|
@@ -57,6 +93,8 @@ async def event_generator():
|
57 | 93 | thread_id=thread_id
|
58 | 94 | )
|
59 | 95 |
|
| 96 | + event_handler = CustomEventHandler() |
| 97 | + |
60 | 98 | async with stream_manager as event_handler:
|
61 | 99 | async for text in event_handler.text_deltas:
|
62 | 100 | yield f"data: {text.replace('\n', '<br>')}\n\n"
|
|
0 commit comments