Skip to content

Commit e96fbf8

Browse files
Chat Document filter (#514)
* Added chat document filter * sending selected filenames for chat response --------- Co-authored-by: vasanthasaikalluri <[email protected]>
1 parent f6fcef1 commit e96fbf8

File tree

8 files changed

+50
-32
lines changed

8 files changed

+50
-32
lines changed

backend/score.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,13 +282,13 @@ async def post_processing(uri=Form(None), userName=Form(None), password=Form(Non
282282
close_db_connection(graph, 'post_processing')
283283

284284
@app.post("/chat_bot")
285-
async def chat_bot(uri=Form(None),model=Form(None),userName=Form(None), password=Form(None), database=Form(None),question=Form(None), session_id=Form(None),mode=Form(None)):
285+
async def chat_bot(uri=Form(None),model=Form(None),userName=Form(None), password=Form(None), database=Form(None),question=Form(None), document_names=Form(None),session_id=Form(None),mode=Form(None)):
286286
logging.info(f"QA_RAG called at {datetime.now()}")
287287
qa_rag_start_time = time.time()
288288
try:
289289
# database = "neo4j"
290290
graph = create_graph_database_connection(uri, userName, password, database)
291-
result = await asyncio.to_thread(QA_RAG,graph=graph,model=model,question=question,session_id=session_id,mode=mode)
291+
result = await asyncio.to_thread(QA_RAG,graph=graph,model=model,question=question,document_names=document_names,session_id=session_id,mode=mode)
292292

293293
total_call_time = time.time() - qa_rag_start_time
294294
logging.info(f"Total Response time is {total_call_time:.2f} seconds")

backend/src/QA_integration_new.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@
2121
from langchain_core.messages import HumanMessage,AIMessage
2222
from src.shared.constants import *
2323
from src.llm import get_llm
24+
import json
2425

2526
load_dotenv()
2627

2728
EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL')
2829
EMBEDDING_FUNCTION , _ = load_embedding_model(EMBEDDING_MODEL)
2930

3031

31-
def get_neo4j_retriever(graph, retrieval_query,index_name="vector", search_k=CHAT_SEARCH_KWARG_K, score_threshold=CHAT_SEARCH_KWARG_SCORE_THRESHOLD):
32+
def get_neo4j_retriever(graph, retrieval_query,document_names,index_name="vector", search_k=CHAT_SEARCH_KWARG_K, score_threshold=CHAT_SEARCH_KWARG_SCORE_THRESHOLD):
3233
try:
3334
neo_db = Neo4jVector.from_existing_index(
3435
embedding=EMBEDDING_FUNCTION,
@@ -37,8 +38,13 @@ def get_neo4j_retriever(graph, retrieval_query,index_name="vector", search_k=CHA
3738
graph=graph
3839
)
3940
logging.info(f"Successfully retrieved Neo4jVector index '{index_name}'")
40-
retriever = neo_db.as_retriever(search_kwargs={'k': search_k, "score_threshold": score_threshold})
41-
logging.info(f"Successfully created retriever for index '{index_name}' with search_k={search_k}, score_threshold={score_threshold}")
41+
if document_names:
42+
document_names= list(map(str.strip, json.loads(document_names)))
43+
retriever = neo_db.as_retriever(search_kwargs={'k': search_k, "score_threshold": score_threshold,'filter':{'fileName': {'$in': document_names}}})
44+
logging.info(f"Successfully created retriever for index '{index_name}' with search_k={search_k}, score_threshold={score_threshold} for documents {document_names}")
45+
else:
46+
retriever = neo_db.as_retriever(search_kwargs={'k': search_k, "score_threshold": score_threshold})
47+
logging.info(f"Successfully created retriever for index '{index_name}' with search_k={search_k}, score_threshold={score_threshold}")
4248
return retriever
4349
except Exception as e:
4450
logging.error(f"Error retrieving Neo4jVector index '{index_name}' or creating retriever: {e}")
@@ -198,13 +204,13 @@ def clear_chat_history(graph,session_id):
198204
"user": "chatbot"
199205
}
200206

201-
def setup_chat(model, graph, session_id, retrieval_query):
207+
def setup_chat(model, graph, session_id, document_names,retrieval_query):
202208
start_time = time.time()
203209
if model in ["diffbot", "LLM_MODEL_CONFIG_ollama_llama3"]:
204210
model = "openai-gpt-4o"
205211
llm,model_name = get_llm(model)
206212
logging.info(f"Model called in chat {model} and model version is {model_name}")
207-
retriever = get_neo4j_retriever(graph=graph,retrieval_query=retrieval_query)
213+
retriever = get_neo4j_retriever(graph=graph,retrieval_query=retrieval_query,document_names=document_names)
208214
doc_retriever = create_document_retriever_chain(llm, retriever)
209215
history = create_neo4j_chat_message_history(graph, session_id)
210216
chat_setup_time = time.time() - start_time
@@ -244,7 +250,7 @@ def summarize_and_log(history, messages, llm):
244250
history_summarized_time = time.time() - start_time
245251
logging.info(f"Chat History summarized in {history_summarized_time:.2f} seconds")
246252

247-
def QA_RAG(graph, model, question, session_id, mode):
253+
def QA_RAG(graph, model, question, document_names,session_id, mode):
248254
try:
249255
logging.info(f"Chat Mode : {mode}")
250256
if mode == "vector":
@@ -259,7 +265,7 @@ def QA_RAG(graph, model, question, session_id, mode):
259265
else:
260266
retrieval_query = VECTOR_GRAPH_SEARCH_QUERY
261267

262-
llm, doc_retriever, history, model_version = setup_chat(model, graph, session_id, retrieval_query)
268+
llm, doc_retriever, history, model_version = setup_chat(model, graph, session_id, document_names,retrieval_query)
263269
messages = history.messages
264270
user_question = HumanMessage(content=question)
265271
messages.append(user_question)

frontend/src/components/ChatBot/Chatbot.tsx

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ const Chatbot: React.FC<ChatbotProps> = (props) => {
2525
const [inputMessage, setInputMessage] = useState('');
2626
const [loading, setLoading] = useState<boolean>(isLoading);
2727
const { userCredentials } = useCredentials();
28-
const { model, chatMode } = useFileContext();
28+
const { model, chatMode, selectedRows } = useFileContext();
2929
const messagesEndRef = useRef<HTMLDivElement>(null);
3030
const [sessionId, setSessionId] = useState<string>(sessionStorage.getItem('session_id') ?? '');
3131
const [showInfoModal, setShowInfoModal] = useState<boolean>(false);
@@ -44,6 +44,8 @@ const Chatbot: React.FC<ChatbotProps> = (props) => {
4444
},
4545
});
4646

47+
const selectedFileNames = selectedRows.map((str) => JSON.parse(str).name);
48+
4749
const handleInputChange = (e: React.ChangeEvent<HTMLInputElement>) => {
4850
setInputMessage(e.target.value);
4951
};
@@ -147,9 +149,15 @@ const Chatbot: React.FC<ChatbotProps> = (props) => {
147149
try {
148150
setInputMessage('');
149151
simulateTypingEffect({ reply: ' ' });
150-
const chatbotAPI = await chatBotAPI(userCredentials as UserCredentials, inputMessage, sessionId, model, chatMode);
152+
const chatbotAPI = await chatBotAPI(
153+
userCredentials as UserCredentials,
154+
inputMessage,
155+
sessionId,
156+
model,
157+
chatMode,
158+
selectedFileNames
159+
);
151160
const chatresponse = chatbotAPI?.response;
152-
console.log('api', chatresponse);
153161
chatbotReply = chatresponse?.data?.data?.message;
154162
chatSources = chatresponse?.data?.data?.info.sources;
155163
chatModel = chatresponse?.data?.data?.info.model;

frontend/src/components/Dropdown.tsx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,11 @@ const DropdownComponent: React.FC<ReusableDropdownProps> = ({
3737
const label =
3838
typeof option === 'string'
3939
? (option.includes('LLM_MODEL_CONFIG_')
40-
? capitalize(option.split('LLM_MODEL_CONFIG_').at(-1) as string)
41-
: capitalize(option)).split('_').join(' ')
40+
? capitalize(option.split('LLM_MODEL_CONFIG_').at(-1) as string)
41+
: capitalize(option)
42+
)
43+
.split('_')
44+
.join(' ')
4245
: capitalize(option.label);
4346
const value = typeof option === 'string' ? option : option.value;
4447
return {

frontend/src/components/FileTable.tsx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,10 @@ const FileTable: React.FC<FileTableProps> = ({ isExpanded, connectionStatus, set
287287
<i>
288288
{(model.includes('LLM_MODEL_CONFIG_')
289289
? capitalize(model.split('LLM_MODEL_CONFIG_').at(-1) as string)
290-
: capitalize(model)).split("_").join(" ")}
290+
: capitalize(model)
291+
)
292+
.split('_')
293+
.join(' ')}
291294
</i>
292295
);
293296
},

frontend/src/components/Graph/GraphViewModal.tsx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,8 @@ const GraphViewModal: React.FunctionComponent<GraphViewModalProps> = ({
233233
if (allNodes.length > 0 && allRelationships.length > 0) {
234234
const { filteredNodes, filteredRelations, filteredScheme } = filterData(
235235
graphType,
236-
finalNodes,
237-
finalRels,
236+
finalNodes ?? [],
237+
finalRels ?? [],
238238
schemeVal
239239
);
240240
setNodes(filteredNodes);
@@ -302,6 +302,10 @@ const GraphViewModal: React.FunctionComponent<GraphViewModalProps> = ({
302302
<div className='my-40 flex items-center justify-center'>
303303
<Banner name='graph banner' description={statusMessage} type={status} />
304304
</div>
305+
) : nodes.length === 0 || relationships.length === 0 ? (
306+
<div className='my-40 flex items-center justify-center'>
307+
<Banner name='graph banner' description='No Entities Found' type='danger' />
308+
</div>
305309
) : (
306310
<>
307311
<div className='flex' style={{ height: '100%' }}>

frontend/src/services/QnaAPI.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ export const chatBotAPI = async (
77
question: string,
88
session_id: string,
99
model: string,
10-
mode = 'vector'
10+
mode: string,
11+
document_names: string[]
1112
) => {
1213
try {
1314
const formData = new FormData();
@@ -19,6 +20,7 @@ export const chatBotAPI = async (
1920
formData.append('session_id', session_id);
2021
formData.append('model', model);
2122
formData.append('mode', mode);
23+
formData.append('document_names', JSON.stringify(document_names));
2224
const startTime = Date.now();
2325
const response = await axios.post(`${url()}/chat_bot`, formData, {
2426
headers: {

frontend/src/utils/Utils.ts

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,10 @@ export const filterData = (
183183
} else if (!graphType.includes('Document') && graphType.includes('Entities') && !graphType.includes('Chunk')) {
184184
// Only Entity
185185
// @ts-ignore
186-
filteredNodes = allNodes.filter((node) => !node.labels.includes('Document') && !node.labels.includes('Chunk'));
186+
const entityNode = allNodes.filter((node) => !node.labels.includes('Document') && !node.labels.includes('Chunk'));
187+
filteredNodes = entityNode ? entityNode : [];
187188
// @ts-ignore
188-
filteredRelations = allRelationships.filter(
189-
(rel) => !['PART_OF', 'FIRST_CHUNK', 'HAS_ENTITY', 'SIMILAR', 'NEXT_CHUNK'].includes(rel.caption)
190-
);
189+
filteredRelations = allRelationships.filter((rel) => !['PART_OF', 'FIRST_CHUNK', 'HAS_ENTITY', 'SIMILAR', 'NEXT_CHUNK'].includes(rel.caption));
191190
filteredScheme = Object.fromEntries(entityTypes.map((key) => [key, scheme[key]])) as Scheme;
192191
} else if (!graphType.includes('Document') && !graphType.includes('Entities') && graphType.includes('Chunk')) {
193192
// Only Chunk
@@ -199,22 +198,15 @@ export const filterData = (
199198
} else if (graphType.includes('Document') && graphType.includes('Entities') && !graphType.includes('Chunk')) {
200199
// Document + Entity
201200
// @ts-ignore
202-
filteredNodes = allNodes.filter(
203-
(node) =>
204-
node.labels.includes('Document') || (!node.labels.includes('Document') && !node.labels.includes('Chunk'))
205-
);
201+
filteredNodes = allNodes.filter((node) =>node.labels.includes('Document') || (!node.labels.includes('Document') && !node.labels.includes('Chunk')));
206202
// @ts-ignore
207-
filteredRelations = allRelationships.filter(
208-
(rel) => !['PART_OF', 'FIRST_CHUNK', 'HAS_ENTITY', 'SIMILAR', 'NEXT_CHUNK'].includes(rel.caption)
209-
);
203+
filteredRelations = allRelationships.filter((rel) => !['PART_OF', 'FIRST_CHUNK', 'HAS_ENTITY', 'SIMILAR', 'NEXT_CHUNK'].includes(rel.caption));
210204
} else if (graphType.includes('Document') && !graphType.includes('Entities') && graphType.includes('Chunk')) {
211205
// Document + Chunk
212206
// @ts-ignore
213207
filteredNodes = allNodes.filter((node) => node.labels.includes('Document') || node.labels.includes('Chunk'));
214208
// @ts-ignore
215-
filteredRelations = allRelationships.filter((rel) =>
216-
['PART_OF', 'FIRST_CHUNK', 'SIMILAR', 'NEXT_CHUNK'].includes(rel.caption)
217-
);
209+
filteredRelations = allRelationships.filter((rel) =>['PART_OF', 'FIRST_CHUNK', 'SIMILAR', 'NEXT_CHUNK'].includes(rel.caption));
218210
filteredScheme = { Document: scheme.Document, Chunk: scheme.Chunk };
219211
} else if (!graphType.includes('Document') && graphType.includes('Entities') && graphType.includes('Chunk')) {
220212
// Chunk + Entity

0 commit comments

Comments (0)