2121from langchain_core .messages import HumanMessage ,AIMessage
2222from src .shared .constants import *
2323from src .llm import get_llm
24+ import json
2425
2526load_dotenv ()
2627
2728EMBEDDING_MODEL = os .getenv ('EMBEDDING_MODEL' )
2829EMBEDDING_FUNCTION , _ = load_embedding_model (EMBEDDING_MODEL )
2930
3031
31- def get_neo4j_retriever (graph , retrieval_query ,index_name = "vector" , search_k = CHAT_SEARCH_KWARG_K , score_threshold = CHAT_SEARCH_KWARG_SCORE_THRESHOLD ):
32+ def get_neo4j_retriever (graph , retrieval_query ,document_names , index_name = "vector" , search_k = CHAT_SEARCH_KWARG_K , score_threshold = CHAT_SEARCH_KWARG_SCORE_THRESHOLD ):
3233 try :
3334 neo_db = Neo4jVector .from_existing_index (
3435 embedding = EMBEDDING_FUNCTION ,
@@ -37,8 +38,13 @@ def get_neo4j_retriever(graph, retrieval_query,index_name="vector", search_k=CHA
3738 graph = graph
3839 )
3940 logging .info (f"Successfully retrieved Neo4jVector index '{ index_name } '" )
40- retriever = neo_db .as_retriever (search_kwargs = {'k' : search_k , "score_threshold" : score_threshold })
41- logging .info (f"Successfully created retriever for index '{ index_name } ' with search_k={ search_k } , score_threshold={ score_threshold } " )
41+ if document_names :
42+ document_names = list (map (str .strip , json .loads (document_names )))
43+ retriever = neo_db .as_retriever (search_kwargs = {'k' : search_k , "score_threshold" : score_threshold ,'filter' :{'fileName' : {'$in' : document_names }}})
44+ logging .info (f"Successfully created retriever for index '{ index_name } ' with search_k={ search_k } , score_threshold={ score_threshold } for documents { document_names } " )
45+ else :
46+ retriever = neo_db .as_retriever (search_kwargs = {'k' : search_k , "score_threshold" : score_threshold })
47+ logging .info (f"Successfully created retriever for index '{ index_name } ' with search_k={ search_k } , score_threshold={ score_threshold } " )
4248 return retriever
4349 except Exception as e :
4450 logging .error (f"Error retrieving Neo4jVector index '{ index_name } ' or creating retriever: { e } " )
@@ -198,13 +204,13 @@ def clear_chat_history(graph,session_id):
198204 "user" : "chatbot"
199205 }
200206
201- def setup_chat (model , graph , session_id , retrieval_query ):
207+ def setup_chat (model , graph , session_id , document_names , retrieval_query ):
202208 start_time = time .time ()
203209 if model in ["diffbot" , "LLM_MODEL_CONFIG_ollama_llama3" ]:
204210 model = "openai-gpt-4o"
205211 llm ,model_name = get_llm (model )
206212 logging .info (f"Model called in chat { model } and model version is { model_name } " )
207- retriever = get_neo4j_retriever (graph = graph ,retrieval_query = retrieval_query )
213+ retriever = get_neo4j_retriever (graph = graph ,retrieval_query = retrieval_query , document_names = document_names )
208214 doc_retriever = create_document_retriever_chain (llm , retriever )
209215 history = create_neo4j_chat_message_history (graph , session_id )
210216 chat_setup_time = time .time () - start_time
@@ -244,7 +250,7 @@ def summarize_and_log(history, messages, llm):
244250 history_summarized_time = time .time () - start_time
245251 logging .info (f"Chat History summarized in { history_summarized_time :.2f} seconds" )
246252
247- def QA_RAG (graph , model , question , session_id , mode ):
253+ def QA_RAG (graph , model , question , document_names , session_id , mode ):
248254 try :
249255 logging .info (f"Chat Mode : { mode } " )
250256 if mode == "vector" :
@@ -259,7 +265,7 @@ def QA_RAG(graph, model, question, session_id, mode):
259265 else :
260266 retrieval_query = VECTOR_GRAPH_SEARCH_QUERY
261267
262- llm , doc_retriever , history , model_version = setup_chat (model , graph , session_id , retrieval_query )
268+ llm , doc_retriever , history , model_version = setup_chat (model , graph , session_id , document_names , retrieval_query )
263269 messages = history .messages
264270 user_question = HumanMessage (content = question )
265271 messages .append (user_question )
0 commit comments