Skip to content

Commit 31f11db

Browse files
authored
Dev (#3)
* Dependencies updated * CORs middleware added * Neo4j exception middleware added * Replaced deprecated LLMChain implementation * Vector chain simplified to use RetrievalQA chain
1 parent c1dbc32 commit 31f11db

File tree

7 files changed

+1244
-586
lines changed

7 files changed

+1244
-586
lines changed

README.md

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,32 @@
11
# Neo4j LangChain Starter Kit
2-
This kit provides a simple [FastAPI](https://fastapi.tiangolo.com/) backend service connected to [OpenAI](https://platform.openai.com/docs/overview) and [Neo4j](https://neo4j.com/developer/) for powering GenAI projects. The Neo4j interface leverages both [Vector Indexes](https://python.langchain.com/docs/integrations/vectorstores/neo4jvector) and [Text2Cypher](https://python.langchain.com/docs/use_cases/graph/integrations/graph_cypher_qa) chains to provide more accurate results.
32

4-
![alt text](https://res.cloudinary.com/dk0tizgdn/image/upload/v1711042573/langchain_starter_kit_sample_jgvnfb.gif "Testing Neo4j LangChain Starter Kit")
3+
This kit provides a simple [FastAPI](https://fastapi.tiangolo.com/) backend service connected to [OpenAI](https://platform.openai.com/docs/overview) and [Neo4j](https://neo4j.com/developer/) for powering GenAI projects. The Neo4j interface leverages both [Vector Indexes](https://python.langchain.com/docs/integrations/vectorstores/neo4jvector) and [Text2Cypher](https://python.langchain.com/docs/use_cases/graph/integrations/graph_cypher_qa) chains to provide more accurate results.
54

5+
![alt text](https://res.cloudinary.com/dk0tizgdn/image/upload/v1711042573/langchain_starter_kit_sample_jgvnfb.gif "Testing Neo4j LangChain Starter Kit")
66

77
## Requirements
8+
89
- [Poetry](https://python-poetry.org/) for virtual enviroment management
910
- [LangChain](https://python.langchain.com/docs/get_started/introduction)
1011
- An [OpenAI API Key](https://openai.com/blog/openai-api)
1112
- A running [local](https://neo4j.com/download/) or [cloud](https://neo4j.com/cloud/platform/aura-graph-database/) Neo4j database
1213

13-
1414
## Usage
15+
16+
Add a .env file to the root folder with the following keys and your own credentials (or these included public access only creds):
17+
18+
```
19+
NEO4J_URI=neo4j+ssc://9fcf58c6.databases.neo4j.io
20+
NEO4J_DATABASE=neo4j
21+
NEO4J_USERNAME=public
22+
NEO4J_PASSWORD=read_only
23+
OPENAI_API_KEY=<your_openai_key_here>
24+
```
25+
26+
Then run: `poetry run uvicorn app.server:app --reload --port=8000 `
27+
28+
Or add env variables at runtime:
29+
1530
```
1631
NEO4J_URI=neo4j+ssc://9fcf58c6.databases.neo4j.io \
1732
NEO4J_DATABASE=neo4j \
@@ -21,24 +36,26 @@ OPENAI_API_KEY=<add_your_openai_key_here> \
2136
poetry run uvicorn app.server:app --reload --port=8000 --log-config=log_conf.yaml
2237
```
2338

24-
*NOTE* the above Neo4j credentials are for read-only access to a hosted sample dataset. Your own OpenAI api key will be needed to run this server.
25-
26-
*NOTE* the `NEO4J_URI` value can use either the neo4j or [bolt](https://neo4j.com/docs/bolt/current/bolt/) uri scheme. For more details on which to use, see this [example](https://neo4j.com/docs/driver-manual/4.0/client-applications/#driver-configuration-examples)
39+
_NOTE_ the above Neo4j credentials are for read-only access to a hosted sample dataset. Your own OpenAI api key will be needed to run this server.
2740

41+
_NOTE_ the `NEO4J_URI` value can use either the neo4j or [bolt](https://neo4j.com/docs/bolt/current/bolt/) uri scheme. For more details on which to use, see this [example](https://neo4j.com/docs/driver-manual/4.0/client-applications/#driver-configuration-examples)
2842

2943
A FastAPI server should now be running on your local port 8000/api/chat.
3044

3145
## Custom Database Setup
46+
3247
If you would like to load your own instance with a subset of this information. Add your own OpenAI key to the Cypher code in the [edgar_import.cypher](edgar_import.cypher) file and run it in your instance's [Neo4j browser](https://neo4j.com/docs/browser-manual/current/).
3348

3449
For more information on how this load script works, see [this notebook](https://github.com/neo4j-examples/sec-edgar-notebooks/blob/main/notebooks/kg-construction/1-mvg.ipynb).
3550

36-
3751
## Docs
52+
3853
FastAPI will make endpoint information and the ability to test from a browser at http://localhost:8000/docs
3954

4055
## Testing
56+
4157
Alternatively, after the server is running, a curl command can be triggered to test the endpoint:
58+
4259
```
4360
curl --location 'http://127.0.0.1:8000/api/chat' \
4461
--header 'Content-Type: application/json' \
@@ -47,10 +64,13 @@ curl --location 'http://127.0.0.1:8000/api/chat' \
4764
```
4865

4966
## Feedback
67+
5068
Please provide feedback and report bugs as [GitHub issues](https://github.com/neo4j-examples/langchain-starter-kit/issues)
5169

5270
## Contributing
71+
5372
Want to improve this kit? See the [contributing guide](./CONTRIBUTING.md)
5473

5574
## Learn More
56-
At [Neo4j GraphAcademy](https://graphacademy.neo4j.com), we offer a wide range of courses completely free of charge, including [Neo4j & LLM Fundamentals](https://graphacademy.neo4j.com/courses/llm-fundamentals/) and [Build a Neo4j-backed Chatbot using Python](https://graphacademy.neo4j.com/courses/llm-chatbot-python/).
75+
76+
At [Neo4j GraphAcademy](https://graphacademy.neo4j.com), we offer a wide range of courses completely free of charge, including [Neo4j & LLM Fundamentals](https://graphacademy.neo4j.com/courses/llm-fundamentals/) and [Build a Neo4j-backed Chatbot using Python](https://graphacademy.neo4j.com/courses/llm-chatbot-python/).

app/graph_chain.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
4242
)
4343

44+
4445
def graph_chain() -> Runnable:
4546

4647
NEO4J_URI = os.getenv("NEO4J_URI")
@@ -56,20 +57,20 @@ def graph_chain() -> Runnable:
5657
username=NEO4J_USERNAME,
5758
password=NEO4J_PASSWORD,
5859
database=NEO4J_DATABASE,
59-
sanitize = True
60+
sanitize=True,
6061
)
6162

6263
graph.refresh_schema()
6364

6465
# Official API doc for GraphCypherQAChain at: https://api.python.langchain.com/en/latest/chains/langchain.chains.graph_qa.base.GraphQAChain.html#
6566
graph_chain = GraphCypherQAChain.from_llm(
66-
cypher_llm = LLM,
67-
qa_llm = LLM,
68-
validate_cypher= True,
67+
cypher_llm=LLM,
68+
qa_llm=LLM,
69+
validate_cypher=True,
6970
graph=graph,
70-
verbose=True,
71-
return_intermediate_steps = True,
72-
return_direct = True,
71+
verbose=True,
72+
return_intermediate_steps=True,
73+
# return_direct = True,
7374
)
7475

75-
return graph_chain
76+
return graph_chain

app/server.py

Lines changed: 86 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,113 @@
11
from __future__ import annotations
2-
from typing import Union
32
from app.graph_chain import graph_chain, CYPHER_GENERATION_PROMPT
43
from app.vector_chain import vector_chain, VECTOR_PROMPT
54
from app.simple_agent import simple_agent_chain
6-
from fastapi import FastAPI
7-
from typing import Union, Optional
5+
from fastapi import FastAPI, Request, Response
6+
from fastapi.middleware.cors import CORSMiddleware
7+
from starlette.middleware.base import BaseHTTPMiddleware
88
from pydantic import BaseModel, Field
9+
from neo4j import exceptions
10+
import logging
911

1012

1113
class ApiChatPostRequest(BaseModel):
12-
message: str = Field(..., description='The chat message to send')
13-
mode: str = Field('agent', description='The mode of the chat message. Current options are: "vector", "graph", "agent". Default is "agent"')
14+
message: str = Field(..., description="The chat message to send")
15+
mode: str = Field(
16+
"agent",
17+
description='The mode of the chat message. Current options are: "vector", "graph", "agent". Default is "agent"',
18+
)
1419

1520

1621
class ApiChatPostResponse(BaseModel):
17-
message: Optional[str] = Field(None, description='The chat message response')
22+
response: str
23+
24+
25+
class Neo4jExceptionMiddleware(BaseHTTPMiddleware):
26+
async def dispatch(self, request: Request, call_next):
27+
try:
28+
response = await call_next(request)
29+
return response
30+
except exceptions.AuthError as e:
31+
msg = f"Neo4j Authentication Error: {e}"
32+
logging.warning(msg)
33+
return Response(content=msg, status_code=400, media_type="text/plain")
34+
except exceptions.ServiceUnavailable as e:
35+
msg = f"Neo4j Database Unavailable Error: {e}"
36+
logging.warning(msg)
37+
return Response(content=msg, status_code=400, media_type="text/plain")
38+
except Exception as e:
39+
msg = f"Neo4j Uncaught Exception: {e}"
40+
logging.error(msg)
41+
return Response(content=msg, status_code=400, media_type="text/plain")
42+
43+
44+
# Allowed CORS origins
45+
origins = [
46+
"http://127.0.0.1:8000", # Alternative localhost address
47+
"http://localhost:8000",
48+
]
49+
50+
app = FastAPI()
1851

52+
# Add CORS middleware to allow cross-origin requests
53+
app.add_middleware(
54+
CORSMiddleware,
55+
allow_origins=origins,
56+
allow_credentials=True,
57+
allow_methods=["*"],
58+
allow_headers=["*"],
59+
)
60+
# Add Neo4j exception handling middleware
61+
app.add_middleware(Neo4jExceptionMiddleware)
1962

20-
app = FastAPI()
2163

2264
@app.post(
23-
'/api/chat',
65+
"/api/chat",
2466
response_model=None,
25-
responses={'201': {'model': ApiChatPostResponse}},
26-
tags=['chat'],
67+
responses={"201": {"model": ApiChatPostResponse}},
68+
tags=["chat"],
2769
)
28-
def send_chat_message(body: ApiChatPostRequest) -> Union[None, ApiChatPostResponse]:
70+
async def send_chat_message(body: ApiChatPostRequest):
2971
"""
3072
Send a chat message
3173
"""
3274

3375
question = body.message
3476

35-
v_response = vector_chain().invoke(
36-
{"question":question},
37-
prompt = VECTOR_PROMPT,
38-
return_only_outputs = True
39-
)
40-
g_response = graph_chain().invoke(
41-
{"query":question},
42-
prompt = CYPHER_GENERATION_PROMPT,
43-
return_only_outputs = True
44-
)
45-
46-
if body.mode == 'vector':
77+
# Simple exception check. See https://neo4j.com/docs/api/python-driver/current/api.html#errors for full set of driver exceptions
78+
79+
if body.mode == "vector":
4780
# Return only the Vector answer
81+
v_response = vector_chain().invoke(
82+
{"query": question}, prompt=VECTOR_PROMPT, return_only_outputs=True
83+
)
4884
response = v_response
49-
elif body.mode == 'graph':
85+
elif body.mode == "graph":
5086
# Return only the Graph (text2Cypher) answer
51-
response = g_response
87+
g_response = graph_chain().invoke(
88+
{"query": question},
89+
prompt=CYPHER_GENERATION_PROMPT,
90+
return_only_outputs=True,
91+
)
92+
response = g_response["result"]
5293
else:
53-
# Return an answer from a chain that composites both the Vector and Graph responses
54-
response = simple_agent_chain().invoke({
55-
"question":question,
56-
"vector_result":v_response,
57-
"graph_result":g_response
58-
})["text"]
59-
60-
return f"{response}", 200
94+
# Return both vector + graph answers
95+
v_response = vector_chain().invoke(
96+
{"query": question}, prompt=VECTOR_PROMPT, return_only_outputs=True
97+
)
98+
g_response = graph_chain().invoke(
99+
{"query": question},
100+
prompt=CYPHER_GENERATION_PROMPT,
101+
return_only_outputs=True,
102+
)["result"]
103+
104+
# Synthesize a composite of both the Vector and Graph responses
105+
response = simple_agent_chain().invoke(
106+
{
107+
"question": question,
108+
"vector_result": v_response,
109+
"graph_result": g_response,
110+
}
111+
)
112+
113+
return response, 200

app/simple_agent.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from langchain.chains import LLMChain
2-
from langchain.chains.conversation.memory import ConversationBufferMemory
1+
from langchain_core.output_parsers import StrOutputParser
32
from langchain.prompts import PromptTemplate
43
from langchain.schema.runnable import Runnable
54
from langchain_openai import ChatOpenAI
5+
from langchain.chains import ConversationChain
6+
from langchain_core.prompts import PromptTemplate
67
import os
78

8-
def simple_agent_chain() -> Runnable:
99

10-
MEMORY = ConversationBufferMemory(memory_key="agent_history", input_key='question', output_key='text', return_messages=True)
10+
def simple_agent_chain() -> Runnable:
1111

1212
final_prompt = """You are a helpful question-answering agent. Your task is to analyze
1313
and synthesize information from two sources: the top result from a similarity search
@@ -19,14 +19,15 @@ def simple_agent_chain() -> Runnable:
1919
Structured information: {graph_result}.
2020
"""
2121

22-
prompt = PromptTemplate.from_template(final_prompt)
22+
prompt = PromptTemplate(
23+
input_variables=["question", "vector_result", "graph_result"],
24+
template=final_prompt,
25+
)
2326

2427
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
2528
LLM = ChatOpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)
29+
output_parser = StrOutputParser()
30+
31+
simple_agent_chain = prompt | LLM | output_parser
2632

27-
simple_agent_chain = LLMChain(
28-
prompt=prompt,
29-
llm=LLM,
30-
memory = MEMORY)
31-
32-
return simple_agent_chain
33+
return simple_agent_chain

app/vector_chain.py

Lines changed: 21 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from langchain.prompts.prompt import PromptTemplate
2-
from langchain.vectorstores.neo4j_vector import Neo4jVector
3-
from langchain.chains import RetrievalQAWithSourcesChain
2+
from langchain_community.vectorstores import Neo4jVector
3+
from langchain.chains import RetrievalQAWithSourcesChain, RetrievalQA
44
from langchain.schema.runnable import Runnable
55
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
66
import logging
@@ -28,9 +28,10 @@
2828
Assistant:"""
2929

3030
VECTOR_PROMPT = PromptTemplate(
31-
input_variables=["input","context"], template=VECTOR_PROMPT_TEMPLATE
31+
input_variables=["input", "context"], template=VECTOR_PROMPT_TEMPLATE
3232
)
3333

34+
3435
def vector_chain() -> Runnable:
3536

3637
NEO4J_URI = os.getenv("NEO4J_URI")
@@ -49,47 +50,26 @@ def vector_chain() -> Runnable:
4950

5051
# Neo4jVector API: https://api.python.langchain.com/en/latest/vectorstores/langchain_community.vectorstores.neo4j_vector.Neo4jVector.html#langchain_community.vectorstores.neo4j_vector.Neo4jVector
5152

52-
try:
53-
logging.debug(f'Attempting to retrieve existing vector index: {index_name}...')
54-
vector_store = Neo4jVector.from_existing_index(
55-
embedding=EMBEDDINGS,
56-
url=NEO4J_URI,
57-
username=NEO4J_USERNAME,
58-
password=NEO4J_PASSWORD,
59-
database=NEO4J_DATABASE,
60-
index_name=index_name,
61-
embedding_node_property=node_property_name,
62-
)
63-
logging.debug(f'Using existing index: {index_name}')
64-
except:
65-
logging.debug(f'No existing index found. Attempting to create a new vector index named {index_name}...')
66-
try:
67-
vector_store = Neo4jVector.from_existing_graph(
68-
embedding=EMBEDDINGS,
69-
url=NEO4J_URI,
70-
username=NEO4J_USERNAME,
71-
password=NEO4J_PASSWORD,
72-
database=NEO4J_DATABASE,
73-
index_name=index_name,
74-
node_label="Chunk",
75-
text_node_properties=["text"],
76-
embedding_node_property=node_property_name,
77-
)
78-
logging.debug(f'Created new index: {index_name}')
79-
except Exception as e:
80-
logging.error(f'Failed to retrieve existing or to create a Neo4jVector: {e}')
81-
82-
if vector_store is None:
83-
logging.error(f'Failed to retrieve or create a Neo4jVector. Exiting.')
84-
exit()
53+
# try:
54+
logging.debug(
55+
f"Attempting to retrieve existing vector index'{index_name}' from Neo4j instance at {NEO4J_URI}..."
56+
)
57+
vector_store = Neo4jVector.from_existing_index(
58+
embedding=EMBEDDINGS,
59+
url=NEO4J_URI,
60+
username=NEO4J_USERNAME,
61+
password=NEO4J_PASSWORD,
62+
database=NEO4J_DATABASE,
63+
index_name=index_name,
64+
embedding_node_property=node_property_name,
65+
)
66+
logging.debug(f"Using existing index: {index_name}")
8567

8668
vector_retriever = vector_store.as_retriever()
8769

88-
vector_chain = RetrievalQAWithSourcesChain.from_chain_type(
70+
vector_chain = RetrievalQA.from_chain_type(
8971
LLM,
90-
chain_type="stuff",
72+
chain_type="stuff",
9173
retriever=vector_retriever,
92-
reduce_k_below_max_tokens = True,
93-
max_tokens_limit=2000
9474
)
95-
return vector_chain
75+
return vector_chain

0 commit comments

Comments
 (0)