|
1 | 1 | from __future__ import annotations |
| 2 | +from typing import Union |
2 | 3 | from app.graph_chain import graph_chain, CYPHER_GENERATION_PROMPT |
3 | 4 | from app.vector_chain import vector_chain, VECTOR_PROMPT |
4 | 5 | from app.simple_agent import simple_agent_chain |
5 | | -from fastapi import FastAPI, Request, Response |
6 | | -from fastapi.middleware.cors import CORSMiddleware |
7 | | -from starlette.middleware.base import BaseHTTPMiddleware |
| 6 | +from fastapi import FastAPI |
| 7 | +from typing import Union, Optional |
8 | 8 | from pydantic import BaseModel, Field |
9 | | -from neo4j import exceptions |
10 | | -import logging |
11 | 9 |
|
12 | 10 |
|
13 | 11 | class ApiChatPostRequest(BaseModel): |
14 | 12 | message: str = Field(..., description="The chat message to send") |
15 | | - mode: str = Field( |
16 | | - "agent", |
17 | | - description='The mode of the chat message. Current options are: "vector", "graph", "agent". Default is "agent"', |
18 | | - ) |
19 | 13 |
|
20 | 14 |
|
21 | 15 | class ApiChatPostResponse(BaseModel): |
22 | | - response: str |
23 | | - |
24 | | - |
25 | | -class Neo4jExceptionMiddleware(BaseHTTPMiddleware): |
26 | | - async def dispatch(self, request: Request, call_next): |
27 | | - try: |
28 | | - response = await call_next(request) |
29 | | - return response |
30 | | - except exceptions.AuthError as e: |
31 | | - msg = f"Neo4j Authentication Error: {e}" |
32 | | - logging.warning(msg) |
33 | | - return Response(content=msg, status_code=400, media_type="text/plain") |
34 | | - except exceptions.ServiceUnavailable as e: |
35 | | - msg = f"Neo4j Database Unavailable Error: {e}" |
36 | | - logging.warning(msg) |
37 | | - return Response(content=msg, status_code=400, media_type="text/plain") |
38 | | - except Exception as e: |
39 | | - msg = f"Neo4j Uncaught Exception: {e}" |
40 | | - logging.error(msg) |
41 | | - return Response(content=msg, status_code=400, media_type="text/plain") |
42 | | - |
43 | | - |
44 | | -# Allowed CORS origins |
45 | | -origins = [ |
46 | | - "http://127.0.0.1:8000", # Alternative localhost address |
47 | | - "http://localhost:8000", |
48 | | -] |
| 16 | + message: Optional[str] = Field(None, description="The chat message response") |
| 17 | + |
49 | 18 |
|
50 | 19 | app = FastAPI() |
51 | 20 |
|
52 | | -# Add CORS middleware to allow cross-origin requests |
53 | | -app.add_middleware( |
54 | | - CORSMiddleware, |
55 | | - allow_origins=origins, |
56 | | - allow_credentials=True, |
57 | | - allow_methods=["*"], |
58 | | - allow_headers=["*"], |
| 21 | + |
| 22 | +@app.post( |
| 23 | + "/api/chat", |
| 24 | + response_model=None, |
| 25 | + responses={"201": {"model": ApiChatPostResponse}}, |
| 26 | + tags=["chat"], |
| 27 | + description="Endpoint utilizing a simple agent to composite responses from the Vector and Graph chains interfacing with a Neo4j instance.", |
59 | 28 | ) |
60 | | -# Add Neo4j exception handling middleware |
61 | | -app.add_middleware(Neo4jExceptionMiddleware) |
| 29 | +def send_chat_message(body: ApiChatPostRequest) -> Union[None, ApiChatPostResponse]: |
| 30 | + """ |
| 31 | + Send a chat message |
| 32 | + """ |
| 33 | + |
| 34 | + question = body.message |
| 35 | + |
| 36 | + v_response = vector_chain().invoke( |
| 37 | + {"question": question}, prompt=VECTOR_PROMPT, return_only_outputs=True |
| 38 | + ) |
| 39 | + g_response = graph_chain().invoke( |
| 40 | + {"query": question}, prompt=CYPHER_GENERATION_PROMPT, return_only_outputs=True |
| 41 | + ) |
| 42 | + |
| 43 | + # Return an answer from a chain that composites both the Vector and Graph responses |
| 44 | + response = simple_agent_chain().invoke( |
| 45 | + { |
| 46 | + "question": question, |
| 47 | + "vector_result": v_response, |
| 48 | + "graph_result": g_response, |
| 49 | + } |
| 50 | + ) |
| 51 | + |
| 52 | + return f"{response}", 200 |
62 | 53 |
|
63 | 54 |
|
64 | 55 | @app.post( |
65 | | - "/api/chat", |
| 56 | + "/api/chat/vector", |
| 57 | + response_model=None, |
| 58 | + responses={"201": {"model": ApiChatPostResponse}}, |
| 59 | + tags=["chat"], |
| 60 | + description="Endpoint for utilizing only vector index for querying Neo4j instance.", |
| 61 | +) |
| 62 | +def send_chat_vector_message( |
| 63 | + body: ApiChatPostRequest, |
| 64 | +) -> Union[None, ApiChatPostResponse]: |
| 65 | + """ |
| 66 | + Send a chat message |
| 67 | + """ |
| 68 | + |
| 69 | + question = body.message |
| 70 | + |
| 71 | + response = vector_chain().invoke( |
| 72 | + {"question": question}, prompt=VECTOR_PROMPT, return_only_outputs=True |
| 73 | + ) |
| 74 | + |
| 75 | + return f"{response}", 200 |
| 76 | + |
| 77 | + |
| 78 | +@app.post( |
| 79 | + "/api/chat/graph", |
66 | 80 | response_model=None, |
67 | 81 | responses={"201": {"model": ApiChatPostResponse}}, |
68 | 82 | tags=["chat"], |
| 83 | + description="Endpoint using only Text2Cypher for querying with Neo4j instance.", |
69 | 84 | ) |
70 | | -async def send_chat_message(body: ApiChatPostRequest): |
| 85 | +def send_chat_graph_message( |
| 86 | + body: ApiChatPostRequest, |
| 87 | +) -> Union[None, ApiChatPostResponse]: |
71 | 88 | """ |
72 | 89 | Send a chat message |
73 | 90 | """ |
74 | 91 |
|
75 | 92 | question = body.message |
76 | 93 |
|
77 | | - # Simple exception check. See https://neo4j.com/docs/api/python-driver/current/api.html#errors for full set of driver exceptions |
78 | | - |
79 | | - if body.mode == "vector": |
80 | | - # Return only the Vector answer |
81 | | - v_response = vector_chain().invoke( |
82 | | - {"query": question}, prompt=VECTOR_PROMPT, return_only_outputs=True |
83 | | - ) |
84 | | - response = v_response |
85 | | - elif body.mode == "graph": |
86 | | - # Return only the Graph (text2Cypher) answer |
87 | | - g_response = graph_chain().invoke( |
88 | | - {"query": question}, |
89 | | - prompt=CYPHER_GENERATION_PROMPT, |
90 | | - return_only_outputs=True, |
91 | | - ) |
92 | | - response = g_response["result"] |
93 | | - else: |
94 | | - # Return both vector + graph answers |
95 | | - v_response = vector_chain().invoke( |
96 | | - {"query": question}, prompt=VECTOR_PROMPT, return_only_outputs=True |
97 | | - ) |
98 | | - g_response = graph_chain().invoke( |
99 | | - {"query": question}, |
100 | | - prompt=CYPHER_GENERATION_PROMPT, |
101 | | - return_only_outputs=True, |
102 | | - )["result"] |
103 | | - |
104 | | - # Synthesize a composite of both the Vector and Graph responses |
105 | | - response = simple_agent_chain().invoke( |
106 | | - { |
107 | | - "question": question, |
108 | | - "vector_result": v_response, |
109 | | - "graph_result": g_response, |
110 | | - } |
111 | | - ) |
112 | | - |
113 | | - return response, 200 |
| 94 | + response = graph_chain().invoke( |
| 95 | + {"query": question}, prompt=CYPHER_GENERATION_PROMPT, return_only_outputs=True |
| 96 | + ) |
| 97 | + |
| 98 | + return f"{response}", 200 |
0 commit comments