|
| 1 | +from fastapi import Body |
| 2 | +from fastapi import Depends |
| 3 | +from fastapi.encoders import jsonable_encoder |
| 4 | +from fastapi.responses import JSONResponse |
| 5 | +from sqlalchemy.ext.asyncio import AsyncSession |
| 6 | + |
| 7 | +from src.fai.api_models.chat import ChatCompletionRequest |
| 8 | +from src.fai.app import fai_app |
| 9 | +from src.fai.dependencies import get_db |
| 10 | +from src.fai.utils.chat.get_base_system_prompt import get_base_system_prompt |
| 11 | +from src.fai.utils.chat.run_rag_on_query import run_rag_on_query |
| 12 | +from src.settings import LOGGER |
| 13 | +from src.settings import anthropic_client |
| 14 | + |
| 15 | + |
| 16 | +@fai_app.post("/chat/{domain}") |
| 17 | +async def chat( |
| 18 | + domain: str, |
| 19 | + body: ChatCompletionRequest = Body(...), |
| 20 | + db: AsyncSession = Depends(get_db), |
| 21 | +) -> JSONResponse: |
| 22 | + LOGGER.info(f"Chatting for domain {domain}") |
| 23 | + try: |
| 24 | + messages = [message.to_dict() for message in body.messages] |
| 25 | + last_user_message = body.messages[-1] if len(body.messages) > 0 else None |
| 26 | + if last_user_message: |
| 27 | + query = last_user_message.content |
| 28 | + documents = run_rag_on_query(query, domain) |
| 29 | + else: |
| 30 | + documents = [] |
| 31 | + |
| 32 | + if body.system_prompt: |
| 33 | + system_prompt = body.system_prompt |
| 34 | + else: |
| 35 | + system_prompt = get_base_system_prompt(domain, "\n\n".join(documents)) |
| 36 | + |
| 37 | + if body.model: |
| 38 | + model = body.model |
| 39 | + else: |
| 40 | + model = "claude-4-sonnet-20250514" |
| 41 | + |
| 42 | + if model == "claude-4-sonnet-20250514": |
| 43 | + response = anthropic_client.messages.create( |
| 44 | + system=system_prompt, |
| 45 | + model=model, |
| 46 | + messages=messages, |
| 47 | + max_tokens=1000, |
| 48 | + ) |
| 49 | + response_content = response.content |
| 50 | + output = [] |
| 51 | + for content_turn in response_content: |
| 52 | + if content_turn.type == "text": |
| 53 | + output.append({"type": "text", "text": content_turn.text}) |
| 54 | + elif content_turn.type == "tool_use": |
| 55 | + output.append({"type": "tool_use", "input": content_turn.input}) |
| 56 | + elif content_turn.type == "tool_result": |
| 57 | + output.append({"type": "thinking", "thinking": content_turn.thinking}) |
| 58 | + else: |
| 59 | + raise ValueError(f"Model {model} not supported") |
| 60 | + |
| 61 | + return JSONResponse(content=jsonable_encoder(output)) |
| 62 | + except Exception as e: |
| 63 | + LOGGER.exception(f"Failed to chat for domain {domain}") |
| 64 | + return JSONResponse(status_code=500, content={"detail": str(e)}) |
0 commit comments