from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_google_vertexai import VertexAIEmbeddings
+from langchain_google_vertexai import ChatVertexAI
+from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
import logging
from langchain_community.chat_message_histories import Neo4jChatMessageHistory
+from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
+from src.shared.common_fn import load_embedding_model
+import re
+
load_dotenv()

openai_api_key = os.environ.get('OPENAI_API_KEY')
-model_version = 'gpt-4-0125-preview'
+
+
+# def get_embedding_function(embedding_model_name: str):
+#     if embedding_model_name == "openai":
+#         embedding_function = OpenAIEmbeddings()
+#         dimension = 1536
+#         logging.info(f"Embedding: Using OpenAI Embeddings, Dimension:{dimension}")
+#     elif embedding_model_name == "vertexai":
+#         embedding_function = VertexAIEmbeddings(
+#             model_name="textembedding-gecko@003"
+#         )
+#         dimension = 768
+#         logging.info(f"Embedding: Using Vertex AI Embeddings, Dimension:{dimension}")
+#     else:
+#         embedding_function = SentenceTransformerEmbeddings(
+#             model_name="all-MiniLM-L6-v2"  # , cache_folder="/embedding_model"
+#         )
+#         dimension = 384
+#         logging.info(f"Embedding: Using SentenceTransformer, Dimension:{dimension}")
+#     return embedding_function
+
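+# Resolve the model name selected in the UI to a configured LangChain chat model.
+# Unrecognized names (e.g. "Diffbot") fall through to the GPT-4 branch below.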
+def get_llm(model: str):
+    if model == "OpenAI GPT 3.5":
+        model_version = "gpt-3.5-turbo-16k"
+        logging.info(f"Chat Model: GPT 3.5, Model Version: {model_version}")
+        llm = ChatOpenAI(model=model_version, temperature=0)
+
+    elif model == "Gemini Pro":
+        # model_version = "gemini-1.0-pro"
+        model_version = 'gemini-1.0-pro-001'
+        logging.info(f"Chat Model: Gemini, Model Version: {model_version}")
+        llm = ChatVertexAI(model_name=model_version,
+                           # max_output_tokens=100,
+                           convert_system_message_to_human=True,
+                           temperature=0,
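+                           # BLOCK_NONE disables Vertex AI safety filtering for every
+                           # harm category, so answers are not withheld by moderation.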
+                           safety_settings={
+                               HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
+                               HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+                               HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+                               HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+                               HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+                           })
+    elif model == "Gemini 1.5 Pro":
+        model_version = "gemini-1.5-pro-preview-0409"
+        logging.info(f"Chat Model: Gemini 1.5, Model Version: {model_version}")
+        llm = ChatVertexAI(model_name=model_version,
+                           # max_output_tokens=100,
+                           convert_system_message_to_human=True,
+                           temperature=0,
+                           safety_settings={
+                               HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
+                               HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+                               HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+                               HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+                               HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+                           })
+    else:
+        # for model == "OpenAI GPT 4" or model == "Diffbot"
+        model_version = "gpt-4-0125-preview"
+        logging.info(f"Chat Model: GPT 4, Model Version: {model_version}")
+        llm = ChatOpenAI(model=model_version, temperature=0)
+    return llm

def vector_embed_results(qa,question):
    vector_res = {}
@@ -92,29 +159,52 @@ def get_chat_history(llm,uri,userName,password,session_id):
        error_message = str(e)
        logging.exception(f'Exception in retrieving chat history:{error_message}')
        # raise Exception(error_message)
-        return ''
+        return ''
+
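+# Split a model reply of the form "answer text [Source: 'a','b']" into the bare
+# message and a list of source names; the final_prompt below instructs the model
+# to emit sources in exactly this trailing format.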
+def extract_and_remove_source(message):
+    pattern = r'\[Source: ([^\]]+)\]'
+    match = re.search(pattern, message)
+    if match:
+        sources_string = match.group(1)
+        sources = [source.strip().strip("'") for source in sources_string.split(',')]
+        new_message = re.sub(pattern, '', message).strip()
+        response = {
+            "message": new_message,
+            "sources": sources
+        }
+    else:
+        response = {
+            "message": message,
+            "sources": []
+        }
+    return response
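+# e.g. extract_and_remove_source("Paris. [Source: 'wiki.txt','notes.pdf']")
+# -> {"message": "Paris.", "sources": ["wiki.txt", "notes.pdf"]}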

-def QA_RAG(uri,userName,password,question,session_id):
+def QA_RAG(uri,model,userName,password,question,session_id):
    try:
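        # The retrieval query joins all matched chunks of a document into one text
        # block, averages their similarity scores, and returns the document URL
        # (falling back to the file name) as citation metadata.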
        retrieval_query = """
        MATCH (node)-[:PART_OF]->(d:Document)
        WITH d, apoc.text.join(collect(node.text),"\n----\n") as text, avg(score) as score
        RETURN text, score, {source: COALESCE(CASE WHEN d.url CONTAINS "None" THEN d.fileName ELSE d.url END, d.fileName)} as metadata
        """
        embedding_model = os.getenv('EMBEDDING_MODEL')
+        embedding_function, _ = load_embedding_model(embedding_model)
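+        # load_embedding_model (from src.shared.common_fn) resolves EMBEDDING_MODEL
+        # to an embedding object; the second tuple element (presumably the embedding
+        # dimension) is unused here.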
        neo_db = Neo4jVector.from_existing_index(
-            embedding=VertexAIEmbeddings(model_name=embedding_model),
+            embedding=embedding_function,
            url=uri,
            username=userName,
            password=password,
            database="neo4j",
            index_name="vector",
            retrieval_query=retrieval_query,
        )
-        llm = ChatOpenAI(model=model_version, temperature=0)
+        # model = "Gemini Pro"
+        llm = get_llm(model=model)
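        # chain_type="stuff" packs the k retrieved chunks directly into the prompt;
        # return_source_documents=True also surfaces the matched chunks and metadata.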
        qa = RetrievalQA.from_chain_type(
-            llm=llm, chain_type="stuff", retriever=neo_db.as_retriever(search_kwargs={'k': 3, "score_threshold": 0.5}), return_source_documents=True
+            llm=llm,
+            chain_type="stuff",
+            retriever=neo_db.as_retriever(search_kwargs={'k': 3, "score_threshold": 0.5}),
+            return_source_documents=True
        )

        vector_res = vector_embed_results(qa,question)
@@ -133,32 +223,58 @@ def QA_RAG(uri,userName,password,question,session_id):

        chat_summary = get_chat_history(llm,uri,userName,password,session_id)

-        final_prompt = f"""You are a helpful question-answering agent. Your task is to analyze
-        and synthesize information from two sources: the top result from a similarity search
-        (unstructured information) and relevant data from a graph database (structured information).
-        If structured information fails to find an answer then use the answer from unstructured information
-        and vice versa. I only want a straightforward answer without mentioning from which source you got the answer. You are also receiving
-        a chat history of the earlier conversation. You should be able to understand the context from the chat history and answer the question.
-        Given the user's query: {question}, provide a meaningful and efficient answer based
-        on the insights derived from the following data:
-        chat_summary:{chat_summary}
-        Structured information: .
-        Unstructured information: {vector_res.get('result','')}.

+        # final_prompt = f"""You are a helpful question-answering agent. Your task is to analyze
+        # and synthesize information from two sources: the top result from a similarity search
+        # (unstructured information) and relevant data from a graph database (structured information).
+        # If structured information fails to find an answer then use the answer from unstructured information
+        # and vice versa. I only want a straightforward answer without mentioning from which source you got the answer. You are also receiving
+        # a chat history of the earlier conversation. You should be able to understand the context from the chat history and answer the question.
+        # Given the user's query: {question}, provide a meaningful and efficient answer based
+        # on the insights derived from the following data:
+        # chat_summary:{chat_summary}
+        # Structured information: .
+        # Unstructured information: {vector_res.get('result','')}.
+        # """
+
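+        # Prompt the model for a concise, context-aware answer, with sources cited
+        # once at the very end as [Source: ...] so extract_and_remove_source can
+        # parse them back out of the reply.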
+        final_prompt = f"""
+        You are an AI-powered question-answering agent tasked with providing accurate and direct responses to user queries. Utilize information from the chat history, current user input, and relevant unstructured data effectively.
+
+        Response Requirements:
+        - Deliver concise and direct answers to the user's query without headers unless requested.
+        - Acknowledge and utilize relevant previous interactions based on the chat history summary.
+        - Respond to initial greetings appropriately, but avoid including a greeting in subsequent responses unless the chat is restarted or significantly paused.
+        - Clearly state if an answer is unknown; avoid speculating.
+
+        Instructions:
+        - Prioritize directly answering the User Input: {question}.
+        - Use the Chat History Summary: {chat_summary} to provide context-aware responses.
+        - Refer to Additional Unstructured Information: {vector_res.get('result', '')} only if it directly relates to the query.
+        - Cite sources clearly when using unstructured data in your response [Sources: {vector_res.get('source', '')}]. Print the sources only once, at the very end, in the format [Source: source1,source2].
+        Ensure that answers are straightforward and context-aware, focusing on being relevant and concise.
        """

        print(final_prompt)
        response = llm.predict(final_prompt)
+        # print(response)
+
        ai_message = response
        user_message = question
        save_chat_history(uri,userName,password,session_id,user_message,ai_message)

-        res = {"session_id":session_id,"message":response,"user":"chatbot"}
+        parsed_response = extract_and_remove_source(response)
+        message = parsed_response["message"]
+        sources = parsed_response["sources"]
+        # print(extract_and_remove_source(response))
+        print(response)
+        res = {"session_id":session_id,"message":message,"sources":sources,"user":"chatbot"}
        return res
    except Exception as e:
        error_message = str(e)
        logging.exception(f'Exception in QA component:{error_message}')
-        # raise Exception(error_message)
-        return {"session_id":session_id,"message":"Something went wrong","user":"chatbot"}
+        message = "Something went wrong"
+        sources = []
+        # raise Exception(error_message)
+        return {"session_id":session_id,"message":message,"sources":sources,"user":"chatbot"}

-