1+ import os
2+ from langchain_community .vectorstores .neo4j_vector import Neo4jVector
3+ from langchain .chains import GraphCypherQAChain
4+ from langchain .graphs import Neo4jGraph
5+ from langchain .chains import RetrievalQA
6+ from langchain_openai import ChatOpenAI
7+ from langchain_openai import OpenAIEmbeddings
8+ import logging
9+ from langchain_community .chat_message_histories import Neo4jChatMessageHistory
10+ import asyncio
11+ from datetime import datetime
12+ from dotenv import load_dotenv
13+ load_dotenv ()
14+
15+ # openai_api_key = os.environ.get('OPENAI_API_KEY')
16+ # model_version='gpt-4-0125-preview'
17+
18+ class ParallelComponent :
19+
20+ def __init__ (self , uri , userName , password , question , session_id ):
21+ self .uri = uri
22+ self .userName = userName
23+ self .password = password
24+ self .question = question
25+ self .session_id = session_id
26+ self .model_version = 'gpt-4-0125-preview'
27+ self .llm = ChatOpenAI (model = self .model_version , temperature = 0 )
28+
29+ # async def execute(self):
30+ # tasks = []
31+
32+ # tasks.append(asyncio.create_task(self._vector_embed_results()))
33+ # tasks.append(asyncio.create_task(self._cypher_results()))
34+ # tasks.append(asyncio.create_task(self._get_chat_history()))
35+
36+ # return await asyncio.gather(*tasks)
37+ async def execute (self ):
38+ tasks = [
39+ self ._vector_embed_results (),
40+ self ._cypher_results (),
41+ self ._get_chat_history ()
42+ ]
43+ return await asyncio .gather (* tasks )
44+
45+ async def _vector_embed_results (self ):
46+ t = datetime .now ()
47+ print ("Vector embeddings start time" ,t )
48+ retrieval_query = """
49+ MATCH (node)-[:PART_OF]->(d:Document)
50+ WITH d, apoc.text.join(collect(node.text),"\n ----\n ") as text, avg(score) as score
51+ RETURN text, score, {source: COALESCE(CASE WHEN d.url CONTAINS "None" THEN d.fileName ELSE d.url END, d.fileName)} as metadata
52+ """
53+ vector_res = {}
54+ try :
55+ neo_db = Neo4jVector .from_existing_index (
56+ embedding = OpenAIEmbeddings (),
57+ url = self .uri ,
58+ username = self .userName ,
59+ password = self .password ,
60+ database = "neo4j" ,
61+ index_name = "vector" ,
62+ retrieval_query = retrieval_query ,
63+ )
64+ # llm = ChatOpenAI(model= model_version, temperature=0)
65+
66+ qa = RetrievalQA .from_chain_type (
67+ llm = self .llm , chain_type = "stuff" , retriever = neo_db .as_retriever (search_kwargs = {'k' : 3 ,"score_threshold" : 0.5 }), return_source_documents = True
68+ )
69+
70+ result = qa ({"query" : self .question })
71+ vector_res ['result' ]= result .get ("result" )
72+ list_source_docs = []
73+ for i in result ["source_documents" ]:
74+ list_source_docs .append (i .metadata ['source' ])
75+ vector_res ['source' ]= list_source_docs
76+ except Exception as e :
77+ error_message = str (e )
78+ logging .exception (f'Exception in vector embedding in QA component:{ error_message } ' )
79+ # raise Exception(error_message)
80+ print ("Vector embeddings duration time" ,datetime .now ()- t )
81+ return vector_res
82+
83+
84+ async def _cypher_results (self ):
85+ try :
86+ t = datetime .now ()
87+ print ("Cypher QA start time" ,t )
88+ cypher_res = {}
89+ graph = Neo4jGraph (
90+ url = self .uri ,
91+ username = self .userName ,
92+ password = self .password
93+ )
94+
95+
96+ graph .refresh_schema ()
97+ cypher_chain = GraphCypherQAChain .from_llm (
98+ graph = graph ,
99+ cypher_llm = ChatOpenAI (temperature = 0 , model = self .model_version ),
100+ qa_llm = ChatOpenAI (temperature = 0 , model = self .model_version ),
101+ validate_cypher = True , # Validate relationship directions
102+ verbose = True ,
103+ top_k = 2
104+ )
105+ try :
106+ cypher_res = cypher_chain .invoke ({"query" : question })
107+ except :
108+ cypher_res = {}
109+
110+ except Exception as e :
111+ error_message = str (e )
112+ logging .exception (f'Exception in CypherQAChain in QA component:{ error_message } ' )
113+ # raise Exception(error_message)
114+ print ("Cypher QA duration" ,datetime .now ()- t )
115+ return cypher_res
116+
117+
118+
119+ async def _get_chat_history (self ):
120+ try :
121+ t = datetime .now ()
122+ print ("Get chat history start time:" ,t )
123+ history = Neo4jChatMessageHistory (
124+ url = self .uri ,
125+ username = self .userName ,
126+ password = self .password ,
127+ session_id = self .session_id
128+ )
129+ chat_history = history .messages
130+
131+ if len (chat_history )== 0 :
132+ return {"result" :"" }
133+ condense_template = f"""Given the following earlier conversation , Summarise the chat history.Make sure to include all the relevant information.
134+ Chat History:
135+ { chat_history } """
136+ chat_summary = self .llm .predict (condense_template )
137+ print ("Get chat history duration time:" ,datetime .now ()- t )
138+ return {"result" :chat_summary }
139+ except Exception as e :
140+ error_message = str (e )
141+ logging .exception (f'Exception in retrieving chat history:{ error_message } ' )
142+ # raise Exception(error_message)
143+ return {"result" :'' }
144+
145+ async def final_prompt (self ,chat_summary ,cypher_res ,vector_res ):
146+ t = datetime .now ()
147+ print ('Final prompt start time:' ,t )
148+ final_prompt = f"""You are a helpful question-answering agent. Your task is to analyze
149+ and synthesize information from two sources: the top result from a similarity search
150+ (unstructured information) and relevant data from a graph database (structured information).
151+ If structured information fails to find an answer then use the answer from unstructured information
152+ and vice versa. I only want a straightforward answer without mentioning from which source you got the answer. You are also receiving
153+ a chat history of the earlier conversation. You should be able to understand the context from the chat history and answer the question.
154+ Given the user's query: { self .question } , provide a meaningful and efficient answer based
155+ on the insights derived from the following data:
156+ chat_summary:{ chat_summary }
157+ Structured information: { cypher_res } .
158+ Unstructured information: { vector_res } .
159+
160+ """
161+ print (final_prompt )
162+ response = self .llm .predict (final_prompt )
163+ ai_message = response
164+ user_message = question
165+ print ('Final prompt duration' ,datetime .now ()- t )
166+ return ai_message ,user_message
167+
168+
169+ async def _save_chat_history (self ,ai_message ,user_message ):
170+ try :
171+ history = Neo4jChatMessageHistory (
172+ url = self .uri ,
173+ username = self .userName ,
174+ password = self .password ,
175+ session_id = self .session_id
176+ )
177+ history .add_user_message (user_message )
178+ history .add_ai_message (ai_message )
179+ logging .info (f'Successfully saved chat history' )
180+ except Exception as e :
181+ error_message = str (e )
182+ logging .exception (f'Exception in saving chat history:{ error_message } ' )
183+ raise Exception (error_message )
184+
185+ # Usage example:
186+
187+ uri = os .environ .get ('NEO4J_URI' )
188+ userName = os .environ .get ('NEO4J_USERNAME' )
189+ password = os .environ .get ('NEO4J_PASSWORD' )
190+ question = 'Do you know my name?'
191+ session_id = 2
192+
193+ async def main (uri ,userName ,password ,question ,session_id ):
194+ t = datetime .now ()
195+ parallel_component = ParallelComponent (uri , userName , password , question , session_id )
196+ f_results = await parallel_component .execute ()
197+ print (f_results )
198+ f_vector_result = f_results [0 ]['result' ]
199+ f_cypher_result = f_results [1 ].get ('result' ,'' )
200+ f_chat_summary = f_results [2 ]['result' ]
201+ print (f_vector_result )
202+ print (f_cypher_result )
203+ print (f_chat_summary )
204+ ai_message ,user_message = await parallel_component .final_prompt (f_chat_summary ,f_cypher_result ,f_vector_result )
205+ # print(asyncio.gather(asyncio.create_taskparallel_component.final_prompt(f_chat_summary,f_cypher_result,f_vector_result)))
206+ await parallel_component ._save_chat_history (ai_message ,user_message )
207+ print ("Total Time taken:" ,datetime .now ()- t )
208+ print ("Response from AI:" ,ai_message )
209+ # Run with an event loop
210+ asyncio .run (main (uri ,userName ,password ,question ,session_id ))
0 commit comments