|
4 | 4 | from fastapi.responses import StreamingResponse
|
5 | 5 | from openai import AsyncOpenAI, AssistantEventHandler
|
6 | 6 | from openai.resources.beta.threads.runs.runs import AsyncAssistantStreamManager
|
| 7 | +from openai.types.beta.threads.runs import RunStep, RunStepDelta |
7 | 8 | from typing_extensions import override
|
8 | 9 | from fastapi.responses import StreamingResponse
|
9 | 10 | from openai import AsyncOpenAI, AssistantEventHandler
|
10 |
| -from fastapi import APIRouter, Depends, Form |
11 |
| -from typing_extensions import override |
| 11 | +from fastapi import APIRouter, Depends, Form, HTTPException |
| 12 | +from pydantic import BaseModel |
| 13 | +from typing import Any |
12 | 14 |
|
13 | 15 | logger: logging.Logger = logging.getLogger("uvicorn.error")
|
14 | 16 | logger.setLevel(logging.DEBUG)
|
15 | 17 |
|
16 | 18 |
|
| 19 | + |
17 | 20 | router: APIRouter = APIRouter(
|
18 | 21 | prefix="/assistants/{assistant_id}/messages/{thread_id}",
|
19 | 22 | tags=["assistants_messages"]
|
|
22 | 25 | # Load Jinja2 templates
|
23 | 26 | templates = Jinja2Templates(directory="templates")
|
24 | 27 |
|
| 28 | +class ToolCallOutputs(BaseModel): |
| 29 | + tool_outputs: Any |
| 30 | + runId: str |
| 31 | + |
| 32 | +async def post_tool_outputs(client: AsyncOpenAI, data: dict, thread_id: str): |
| 33 | + try: |
| 34 | + # Parse the JSON body into the ToolCallOutputs model |
| 35 | + tool_call_outputs = ToolCallOutputs(**data) |
| 36 | + |
| 37 | + # Submit tool outputs stream |
| 38 | + stream = await client.beta.threads.runs.submit_tool_outputs_stream( |
| 39 | + thread_id, |
| 40 | + tool_call_outputs.runId, |
| 41 | + {"tool_outputs": tool_call_outputs.tool_outputs} |
| 42 | + ) |
| 43 | + |
| 44 | + # Return the stream as a response |
| 45 | + return stream.to_readable_stream() |
| 46 | + except Exception as e: |
| 47 | + logger.error(f"Error submitting tool outputs: {e}") |
| 48 | + raise HTTPException(status_code=500, detail=str(e)) |
| 49 | + |
| 50 | + |
25 | 51 | class CustomEventHandler(AssistantEventHandler):
|
26 | 52 | def __init__(self):
|
27 | 53 | super().__init__()
|
28 |
| - self.message_content = "" |
29 | 54 |
|
30 | 55 | @override
|
31 |
| - def on_text_created(self, text) -> None: |
32 |
| - print(f"\nassistant > ", end="", flush=True) |
33 |
| - |
34 |
| - @override |
35 |
| - def on_text_delta(self, delta, snapshot): |
36 |
| - print(delta.value, end="", flush=True) |
37 |
| - self.message_content += delta.value |
38 |
| - |
39 |
| - @override |
40 |
| - def on_text_done(self, text): |
41 |
| - print(f"\nassistant > done", flush=True) |
42 |
| - |
43 | 56 | def on_tool_call_created(self, tool_call):
|
44 |
| - print(f"\nassistant > {tool_call.type}\n", flush=True) |
| 57 | + yield f"<span class='tool-call'>Calling {tool_call.type} tool</span>\n" |
45 | 58 |
|
| 59 | + @override |
46 | 60 | def on_tool_call_delta(self, delta, snapshot):
|
47 | 61 | if delta.type == 'code_interpreter':
|
48 | 62 | if delta.code_interpreter.input:
|
49 |
| - print(delta.code_interpreter.input, end="", flush=True) |
| 63 | + yield f"<span class='code'>{delta.code_interpreter.input}</span>\n" |
50 | 64 | if delta.code_interpreter.outputs:
|
51 |
| - print(f"\n\noutput >", flush=True) |
52 | 65 | for output in delta.code_interpreter.outputs:
|
53 | 66 | if output.type == "logs":
|
54 |
| - print(f"\n{output.logs}", flush=True) |
| 67 | + yield f"<span class='console'>{output.logs}</span>\n" |
| 68 | + if delta.type == "function": |
| 69 | + yield |
| 70 | + if delta.type == "file_search": |
| 71 | + yield |
| 72 | + |
55 | 73 |
|
56 | 74 | # Send a new message to a thread
|
57 | 75 | @router.post("/send")
|
|
0 commit comments