
Commit 1dbcbe3

Merge pull request #148 from neo4j-labs/Chatbot
Chatbot Chat history component updated
2 parents 3179a6d + 89d995d

3 files changed: 283 additions, 21 deletions


backend/score.py
Lines changed: 2 additions & 2 deletions

@@ -173,8 +173,8 @@ async def chat_bot(uri=Form(None),
                    userName=Form(None),
                    password=Form(None),
                    question=Form(None),
-                   model=Form(None)):
-    result = await asyncio.to_thread(QA_RAG,uri=uri,userName=userName,password=password,model_version=model,question=question)
+                   session_id=Form(None)):
+    result = await asyncio.to_thread(QA_RAG,uri=uri,userName=userName,password=password,question=question,session_id=session_id)
     return result

 @app.post("/connect")
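With this change the chatbot endpoint no longer accepts a model form field (the backend now pins model_version inside QA_integration.py) and instead takes a session_id, so each answer can be tied to a per-session chat history. A hedged client-side sketch of the new form contract, assuming the route is mounted at /chatbot and served locally on port 8000 (neither is visible in this hunk):

# Illustrative call against the updated endpoint; the URL, route path and credentials
# are placeholder assumptions, only the form field names come from the diff above.
import requests

form_data = {
    "uri": "neo4j+s://example.databases.neo4j.io",
    "userName": "neo4j",
    "password": "secret",
    "question": "What documents mention Neo4j?",
    "session_id": "demo-session",   # new field, replaces the removed "model" field
}
response = requests.post("http://localhost:8000/chatbot", data=form_data)
print(response.json())  # QA_RAG returns {"session_id": ..., "message": ..., "user": "chatbot"}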

backend/src/QA_integration.py
Lines changed: 71 additions & 19 deletions

@@ -7,58 +7,95 @@
 from langchain_openai import ChatOpenAI
 from langchain_openai import OpenAIEmbeddings
 import logging
+from langchain_community.chat_message_histories import Neo4jChatMessageHistory
+import asyncio
 load_dotenv()

 openai_api_key = os.environ.get('OPENAI_API_KEY')
+model_version='gpt-4-0125-preview'

 def vector_embed_results(qa,question):
     vector_res={}
     try:
-        # question ="What do you know about machine learning"
         result = qa({"query": question})
-        vector_res['result']=result["result"]
+        vector_res['result']=result.get("result")
         list_source_docs=[]
         for i in result["source_documents"]:
             list_source_docs.append(i.metadata['source'])
         vector_res['source']=list_source_docs
     except Exception as e:
         error_message = str(e)
         logging.exception(f'Exception in vector embedding in QA component:{error_message}')
-        raise Exception(error_message)
+        # raise Exception(error_message)

     return vector_res

-def cypher_results(graph,question,model_version):
+def cypher_results(graph,question):
     cypher_res={}
     try:
         graph.refresh_schema()
         cypher_chain = GraphCypherQAChain.from_llm(
             graph=graph,
-            # cypher_llm=ChatOpenAI(temperature=0, model="gpt-4"),
             cypher_llm=ChatOpenAI(temperature=0, model=model_version),
             qa_llm=ChatOpenAI(temperature=0, model=model_version),
             validate_cypher=True, # Validate relationship directions
             verbose=True,
             top_k=2
         )
-
-        cypher_res=cypher_chain.invoke({"query": question})
+        try:
+            cypher_res=cypher_chain.invoke({"query": question})
+        except:
+            cypher_res={}

     except Exception as e:
         error_message = str(e)
         logging.exception(f'Exception in CypherQAChain in QA component:{error_message}')
-        raise Exception(error_message)
+        # raise Exception(error_message)

     return cypher_res

+def save_chat_history(uri,userName,password,session_id,user_message,ai_message):
+    try:
+        history = Neo4jChatMessageHistory(
+            url=uri,
+            username=userName,
+            password=password,
+            session_id=session_id
+        )
+        history.add_user_message(user_message)
+        history.add_ai_message(ai_message)
+        logging.info(f'Successfully saved chat history')
+    except Exception as e:
+        error_message = str(e)
+        logging.exception(f'Exception in saving chat history:{error_message}')
+        # raise Exception(error_message)
+

+def get_chat_history(llm,uri,userName,password,session_id):
+    try:
+        history = Neo4jChatMessageHistory(
+            url=uri,
+            username=userName,
+            password=password,
+            session_id=session_id
+        )
+        chat_history=history.messages

-def QA_RAG(uri,userName,password,model_version,question):
+        if len(chat_history)==0:
+            return ""
+        condense_template = f"""Given the following earlier conversation , Summarise the chat history.Make sure to include all the relevant information.
+        Chat History:
+        {chat_history}"""
+        chat_summary=llm.predict(condense_template)
+        return chat_summary
+    except Exception as e:
+        error_message = str(e)
+        logging.exception(f'Exception in retrieving chat history:{error_message}')
+        # raise Exception(error_message)
+        return ''
+
+def QA_RAG(uri,userName,password,question,session_id):
     try:
-        if model_version=='OpenAI GPT 3.5':
-            model_version='gpt-3.5-turbo'
-        elif model_version=='OpenAI GPT 4':
-            model_version='gpt-4-0125-preview'
         retrieval_query="""
         MATCH (node)-[:PART_OF]->(d:Document)
         WITH d, apoc.text.join(collect(node.text),"\n----\n") as text, avg(score) as score
@@ -77,7 +114,7 @@ def QA_RAG(uri,userName,password,model_version,question):
         llm = ChatOpenAI(model= model_version, temperature=0)

         qa = RetrievalQA.from_chain_type(
-            llm=llm, chain_type="stuff", retriever=neo_db.as_retriever(search_kwargs={"score_threshold": 0.5}), return_source_documents=True
+            llm=llm, chain_type="stuff", retriever=neo_db.as_retriever(search_kwargs={'k': 3,"score_threshold": 0.5}), return_source_documents=True
         )

         graph = Neo4jGraph(
@@ -86,24 +123,39 @@ def QA_RAG(uri,userName,password,model_version,question):
             password=password
         )
         vector_res=vector_embed_results(qa,question)
+        print('Response from Vector embeddings')
         print(vector_res)
-        cypher_res= cypher_results(graph,question,model_version)
+        cypher_res= cypher_results(graph,question)
+        print('Response from CypherQAChain')
         print(cypher_res)
+
+        chat_summary=get_chat_history(llm,uri,userName,password,session_id)
+
         final_prompt = f"""You are a helpful question-answering agent. Your task is to analyze
         and synthesize information from two sources: the top result from a similarity search
-        (unstructured information) and relevant data from a graph database (structured information).
+        (unstructured information) and relevant data from a graph database (structured information).
+        If structured information fails to find an answer then use the answer from unstructured information
+        and vice versa. I only want a straightforward answer without mentioning from which source you got the answer. You are also receiving
+        a chat history of the earlier conversation. You should be able to understand the context from the chat history and answer the question.
         Given the user's query: {question}, provide a meaningful and efficient answer based
         on the insights derived from the following data:
+        chat_summary:{chat_summary}
         Structured information: {cypher_res.get('result','')}.
         Unstructured information: {vector_res.get('result','')}.

-        If structured information fails to find an answer then use the answer from unstructured information and vice versa. I only want a straightforward answer without mentioning from which source you got the answer.
         """
         print(final_prompt)
         response = llm.predict(final_prompt)
-        res={"message":response,"user":"chatbot"}
+        ai_message=response
+        user_message=question
+        save_chat_history(uri,userName,password,session_id,user_message,ai_message)
+
+        res={"session_id":session_id,"message":response,"user":"chatbot"}
         return res
     except Exception as e:
         error_message = str(e)
         logging.exception(f'Exception in in QA component:{error_message}')
-        raise Exception(error_message)
+        # raise Exception(error_message)
+        return {"session_id":session_id,"message":"Something went wrong","user":"chatbot"}
+
+
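The new save_chat_history and get_chat_history helpers are thin wrappers around LangChain's Neo4jChatMessageHistory, keyed by session_id. A minimal round-trip sketch of that store, with placeholder connection values (QA_RAG receives the real ones as uri/userName/password):

from langchain_community.chat_message_histories import Neo4jChatMessageHistory

# Placeholder connection details, for illustration only.
history = Neo4jChatMessageHistory(
    url="neo4j+s://example.databases.neo4j.io",
    username="neo4j",
    password="secret",
    session_id="demo-session",
)
history.add_user_message("My name is Alice.")           # what save_chat_history writes
history.add_ai_message("Nice to meet you, Alice.")
print(history.messages)                                  # what get_chat_history summarises via llm.predict

Because the messages live in Neo4j itself, a later QA_RAG call with the same session_id sees the earlier turns and folds their summary into the final prompt.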

backend/src/QA_optimization.py
Lines changed: 210 additions & 0 deletions

@@ -0,0 +1,210 @@
+import os
+from langchain_community.vectorstores.neo4j_vector import Neo4jVector
+from langchain.chains import GraphCypherQAChain
+from langchain.graphs import Neo4jGraph
+from langchain.chains import RetrievalQA
+from langchain_openai import ChatOpenAI
+from langchain_openai import OpenAIEmbeddings
+import logging
+from langchain_community.chat_message_histories import Neo4jChatMessageHistory
+import asyncio
+from datetime import datetime
+from dotenv import load_dotenv
+load_dotenv()
+
+# openai_api_key = os.environ.get('OPENAI_API_KEY')
+# model_version='gpt-4-0125-preview'
+
+class ParallelComponent:
+
+    def __init__(self, uri, userName, password, question, session_id):
+        self.uri = uri
+        self.userName = userName
+        self.password = password
+        self.question = question
+        self.session_id = session_id
+        self.model_version='gpt-4-0125-preview'
+        self.llm = ChatOpenAI(model= self.model_version, temperature=0)
+
+    # async def execute(self):
+    #     tasks = []
+
+    #     tasks.append(asyncio.create_task(self._vector_embed_results()))
+    #     tasks.append(asyncio.create_task(self._cypher_results()))
+    #     tasks.append(asyncio.create_task(self._get_chat_history()))
+
+    #     return await asyncio.gather(*tasks)
+    async def execute(self):
+        tasks = [
+            self._vector_embed_results(),
+            self._cypher_results(),
+            self._get_chat_history()
+        ]
+        return await asyncio.gather(*tasks)
+
+    async def _vector_embed_results(self):
+        t=datetime.now()
+        print("Vector embeddings start time",t)
+        retrieval_query="""
+        MATCH (node)-[:PART_OF]->(d:Document)
+        WITH d, apoc.text.join(collect(node.text),"\n----\n") as text, avg(score) as score
+        RETURN text, score, {source: COALESCE(CASE WHEN d.url CONTAINS "None" THEN d.fileName ELSE d.url END, d.fileName)} as metadata
+        """
+        vector_res={}
+        try:
+            neo_db=Neo4jVector.from_existing_index(
+                embedding=OpenAIEmbeddings(),
+                url=self.uri,
+                username=self.userName,
+                password=self.password,
+                database="neo4j",
+                index_name="vector",
+                retrieval_query=retrieval_query,
+            )
+            # llm = ChatOpenAI(model= model_version, temperature=0)
+
+            qa = RetrievalQA.from_chain_type(
+                llm=self.llm, chain_type="stuff", retriever=neo_db.as_retriever(search_kwargs={'k': 3,"score_threshold": 0.5}), return_source_documents=True
+            )
+
+            result = qa({"query": self.question})
+            vector_res['result']=result.get("result")
+            list_source_docs=[]
+            for i in result["source_documents"]:
+                list_source_docs.append(i.metadata['source'])
+            vector_res['source']=list_source_docs
+        except Exception as e:
+            error_message = str(e)
+            logging.exception(f'Exception in vector embedding in QA component:{error_message}')
+            # raise Exception(error_message)
+        print("Vector embeddings duration time",datetime.now()-t)
+        return vector_res
+
+
+    async def _cypher_results(self):
+        try:
+            t=datetime.now()
+            print("Cypher QA start time",t)
+            cypher_res={}
+            graph = Neo4jGraph(
+                url=self.uri,
+                username=self.userName,
+                password=self.password
+            )
+
+
+            graph.refresh_schema()
+            cypher_chain = GraphCypherQAChain.from_llm(
+                graph=graph,
+                cypher_llm=ChatOpenAI(temperature=0, model=self.model_version),
+                qa_llm=ChatOpenAI(temperature=0, model=self.model_version),
+                validate_cypher=True, # Validate relationship directions
+                verbose=True,
+                top_k=2
+            )
+            try:
+                cypher_res=cypher_chain.invoke({"query": question})
+            except:
+                cypher_res={}
+
+        except Exception as e:
+            error_message = str(e)
+            logging.exception(f'Exception in CypherQAChain in QA component:{error_message}')
+            # raise Exception(error_message)
+        print("Cypher QA duration",datetime.now()-t)
+        return cypher_res
+
+
+
+    async def _get_chat_history(self):
+        try:
+            t=datetime.now()
+            print("Get chat history start time:",t)
+            history = Neo4jChatMessageHistory(
+                url=self.uri,
+                username=self.userName,
+                password=self.password,
+                session_id=self.session_id
+            )
+            chat_history=history.messages
+
+            if len(chat_history)==0:
+                return {"result":""}
+            condense_template = f"""Given the following earlier conversation , Summarise the chat history.Make sure to include all the relevant information.
+            Chat History:
+            {chat_history}"""
+            chat_summary=self.llm.predict(condense_template)
+            print("Get chat history duration time:",datetime.now()-t)
+            return {"result":chat_summary}
+        except Exception as e:
+            error_message = str(e)
+            logging.exception(f'Exception in retrieving chat history:{error_message}')
+            # raise Exception(error_message)
+            return {"result":''}
+
+    async def final_prompt(self,chat_summary,cypher_res,vector_res):
+        t=datetime.now()
+        print('Final prompt start time:',t)
+        final_prompt = f"""You are a helpful question-answering agent. Your task is to analyze
+        and synthesize information from two sources: the top result from a similarity search
+        (unstructured information) and relevant data from a graph database (structured information).
+        If structured information fails to find an answer then use the answer from unstructured information
+        and vice versa. I only want a straightforward answer without mentioning from which source you got the answer. You are also receiving
+        a chat history of the earlier conversation. You should be able to understand the context from the chat history and answer the question.
+        Given the user's query: {self.question}, provide a meaningful and efficient answer based
+        on the insights derived from the following data:
+        chat_summary:{chat_summary}
+        Structured information: {cypher_res}.
+        Unstructured information: {vector_res}.
+
+        """
+        print(final_prompt)
+        response = self.llm.predict(final_prompt)
+        ai_message=response
+        user_message=question
+        print('Final prompt duration',datetime.now()-t)
+        return ai_message,user_message
+
+
+    async def _save_chat_history(self,ai_message,user_message):
+        try:
+            history = Neo4jChatMessageHistory(
+                url=self.uri,
+                username=self.userName,
+                password=self.password,
+                session_id=self.session_id
+            )
+            history.add_user_message(user_message)
+            history.add_ai_message(ai_message)
+            logging.info(f'Successfully saved chat history')
+        except Exception as e:
+            error_message = str(e)
+            logging.exception(f'Exception in saving chat history:{error_message}')
+            raise Exception(error_message)
+
+# Usage example:
+
+uri=os.environ.get('NEO4J_URI')
+userName=os.environ.get('NEO4J_USERNAME')
+password=os.environ.get('NEO4J_PASSWORD')
+question='Do you know my name?'
+session_id=2
+
+async def main(uri,userName,password,question,session_id):
+    t=datetime.now()
+    parallel_component = ParallelComponent(uri, userName, password, question, session_id)
+    f_results=await parallel_component.execute()
+    print(f_results)
+    f_vector_result=f_results[0]['result']
+    f_cypher_result=f_results[1].get('result','')
+    f_chat_summary=f_results[2]['result']
+    print(f_vector_result)
+    print(f_cypher_result)
+    print(f_chat_summary)
+    ai_message,user_message=await parallel_component.final_prompt(f_chat_summary,f_cypher_result,f_vector_result)
+    # print(asyncio.gather(asyncio.create_taskparallel_component.final_prompt(f_chat_summary,f_cypher_result,f_vector_result)))
+    await parallel_component._save_chat_history(ai_message,user_message)
+    print("Total Time taken:",datetime.now()-t)
+    print("Response from AI:",ai_message)
+# Run with an event loop
+asyncio.run(main(uri,userName,password,question,session_id))
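ParallelComponent.execute relies on asyncio.gather to run the vector search, the Cypher QA chain and the chat-history lookup concurrently and to return their results in order. A stand-alone sketch of that pattern, using illustrative placeholder coroutines rather than anything from the repo:

import asyncio

async def fake_vector_search():
    await asyncio.sleep(1)      # stands in for the Neo4jVector / RetrievalQA step
    return {"result": "vector answer"}

async def fake_cypher_search():
    await asyncio.sleep(1)      # stands in for GraphCypherQAChain.invoke
    return {"result": "cypher answer"}

async def fake_chat_history():
    await asyncio.sleep(1)      # stands in for Neo4jChatMessageHistory + summarisation
    return {"result": "chat summary"}

async def main():
    # gather preserves argument order, so results[0]/[1]/[2] map to the three steps
    results = await asyncio.gather(fake_vector_search(), fake_cypher_search(), fake_chat_history())
    print(results)              # finishes in roughly 1s instead of 3s

asyncio.run(main())

Note that gather only overlaps work while a coroutine is actually awaiting; blocking library calls still run back to back unless they are offloaded, for example with asyncio.to_thread, which is how score.py dispatches the synchronous QA_RAG.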
