-import os
 import logging

-from dotenv import load_dotenv
 from fastapi.templating import Jinja2Templates
-from fastapi import APIRouter, Form, HTTPException, Depends
+from fastapi import APIRouter, Form, Depends, Request
 from fastapi.responses import StreamingResponse
 from openai import AsyncOpenAI
 from openai.resources.beta.threads.runs.runs import AsyncAssistantStreamManager

-import json

 logger: logging.Logger = logging.getLogger("uvicorn.error")
 logger.setLevel(logging.DEBUG)

 # Initialize the router
 router: APIRouter = APIRouter(
-    prefix="/assistants/{assistant_id}/messages",
+    prefix="/assistants/{assistant_id}/messages/{thread_id}",
     tags=["assistants_messages"]
 )

@@ ... @@
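The collapsed context between these hunks is presumably where the `templates` object used by the handlers below is initialized. A minimal sketch of that definition, assuming the conventional directory name:

# Assumed to live in the collapsed lines; the directory name is a guess.
templates: Jinja2Templates = Jinja2Templates(directory="templates")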
 # Send a new message to a thread
 @router.post("/send")
 async def post_message(
+    request: Request,
+    assistant_id: str,
+    thread_id: str,
     userInput: str = Form(...),
-    thread_id: str = Form(),
     client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
 ) -> dict:
     # Create a new message in the thread
     await client.beta.threads.messages.create(
         thread_id=thread_id,
         role="user",
         content=userInput
     )

-    return templates.TemplateResponse("components/chat-turn.html")
+    return templates.TemplateResponse(
+        "components/chat-turn.html",
+        {
+            "request": request,
+            "user_input": userInput,
+            "assistant_id": assistant_id,
+            "thread_id": thread_id
+        }
+    )

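For reference, a hypothetical smoke test of the send endpoint, assuming the app runs locally on port 8000 and using httpx purely for illustration; the assistant and thread IDs are placeholders:

import asyncio

import httpx


async def send_message() -> None:
    # Placeholder IDs; real values come from the assistant and thread you created.
    url = "http://localhost:8000/assistants/asst_123/messages/thread_abc/send"
    async with httpx.AsyncClient() as http:
        # userInput is form-encoded to match the Form(...) parameter in post_message.
        response = await http.post(url, data={"userInput": "Hello, assistant!"})
        print(response.status_code, response.text)


asyncio.run(send_message())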
 @router.get("/receive")
 async def stream_response(
-    thread_id: str | None = None,
+    assistant_id: str,
+    thread_id: str,
     client: AsyncOpenAI = Depends(lambda: AsyncOpenAI())
 ) -> StreamingResponse:
-    if not thread_id:
-        raise HTTPException(status_code=400, message="thread_id is required")
-
     # Create a generator to stream the response from the assistant
-    load_dotenv()
     async def event_generator():
         stream: AsyncAssistantStreamManager = client.beta.threads.runs.stream(
-            assistant_id=os.getenv("ASSISTANT_ID"),
+            assistant_id=assistant_id,
             thread_id=thread_id
         )
         async with stream as stream_manager:
             async for text in stream_manager.text_deltas:
-                yield f"data: {text}"
+                yield f"data: {text}\n\n"

             # Send a done event when the stream is complete
-            yield f"event: EndMessage"
+            yield f"event: EndMessage\n\n"

     return StreamingResponse(
         event_generator(),