Skip to content

Commit f6c29b1

Browse files
Graph chat (#525)
* added graph mode
* added doc names parameter
* chat modes configurable and added the tabs for graph only mode
* changed the example env
* cypher cleaning
* fixed the default tab mode for graph mode
--------
Co-authored-by: vasanthasaikalluri <[email protected]>
1 parent a24c223 commit f6c29b1

File tree

11 files changed

+256
-109
lines changed

11 files changed

+256
-109
lines changed

backend/score.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,10 @@ async def chat_bot(uri=Form(None),model=Form(None),userName=Form(None), password
286286
logging.info(f"QA_RAG called at {datetime.now()}")
287287
qa_rag_start_time = time.time()
288288
try:
289-
# database = "neo4j"
290-
graph = create_graph_database_connection(uri, userName, password, database)
289+
if mode == "graph":
290+
graph = Neo4jGraph( url=uri,username=userName,password=password,database=database,sanitize = True, refresh_schema=True)
291+
else:
292+
graph = create_graph_database_connection(uri, userName, password, database)
291293
result = await asyncio.to_thread(QA_RAG,graph=graph,model=model,question=question,document_names=document_names,session_id=session_id,mode=mode)
292294

293295
total_call_time = time.time() - qa_rag_start_time

backend/src/QA_integration_new.py

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from langchain_core.messages import HumanMessage,AIMessage
2222
from src.shared.constants import *
2323
from src.llm import get_llm
24+
from langchain.chains import GraphCypherQAChain
2425
import json
2526

2627
## Chat models
@@ -226,11 +227,10 @@ def setup_chat(model, graph, session_id, document_names,retrieval_query):
226227
logging.info(f"Model called in chat {model} and model version is {model_name}")
227228
retriever = get_neo4j_retriever(graph=graph,retrieval_query=retrieval_query,document_names=document_names)
228229
doc_retriever = create_document_retriever_chain(llm, retriever)
229-
history = create_neo4j_chat_message_history(graph, session_id)
230230
chat_setup_time = time.time() - start_time
231231
logging.info(f"Chat setup completed in {chat_setup_time:.2f} seconds")
232232

233-
return llm, doc_retriever, history, model_name
233+
return llm, doc_retriever, model_name
234234

235235
def retrieve_documents(doc_retriever, messages):
236236
start_time = time.time()
@@ -264,25 +264,87 @@ def summarize_and_log(history, messages, llm):
264264
history_summarized_time = time.time() - start_time
265265
logging.info(f"Chat History summarized in {history_summarized_time:.2f} seconds")
266266

267+
268+
def create_graph_chain(model, graph):
269+
try:
270+
logging.info(f"Graph QA Chain using LLM model: {model}")
271+
272+
cypher_llm,model_name = get_llm(model)
273+
qa_llm,model_name = get_llm(model)
274+
graph_chain = GraphCypherQAChain.from_llm(
275+
cypher_llm=cypher_llm,
276+
qa_llm=qa_llm,
277+
validate_cypher= True,
278+
graph=graph,
279+
# verbose=True,
280+
return_intermediate_steps = True,
281+
top_k=3
282+
)
283+
284+
logging.info("GraphCypherQAChain instance created successfully.")
285+
return graph_chain,qa_llm,model_name
286+
287+
except Exception as e:
288+
logging.error(f"An error occurred while creating the GraphCypherQAChain instance. : {e}")
289+
290+
291+
def get_graph_response(graph_chain, question):
292+
try:
293+
cypher_res = graph_chain.invoke({"query": question})
294+
295+
response = cypher_res.get("result")
296+
cypher_query = ""
297+
context = []
298+
299+
for step in cypher_res.get("intermediate_steps", []):
300+
if "query" in step:
301+
cypher_string = step["query"]
302+
cypher_query = cypher_string.replace("cypher\n", "").replace("\n", " ").strip()
303+
elif "context" in step:
304+
context = step["context"]
305+
return {
306+
"response": response,
307+
"cypher_query": cypher_query,
308+
"context": context
309+
}
310+
311+
except Exception as e:
312+
logging.error("An error occurred while getting the graph response : {e}")
313+
267314
def QA_RAG(graph, model, question, document_names,session_id, mode):
268315
try:
269316
logging.info(f"Chat Mode : {mode}")
270-
if mode == "vector":
271-
retrieval_query = VECTOR_SEARCH_QUERY
272-
elif mode == "graph":
273-
#WIP
274-
result = {
275-
"session_id": session_id,
317+
history = create_neo4j_chat_message_history(graph, session_id)
318+
messages = history.messages
319+
user_question = HumanMessage(content=question)
320+
messages.append(user_question)
321+
322+
if mode == "graph":
323+
graph_chain, qa_llm,model_version = create_graph_chain(model,graph)
324+
graph_response = get_graph_response(graph_chain,question)
325+
ai_response = AIMessage(content=graph_response["response"])
326+
messages.append(ai_response)
327+
summarize_and_log(history, messages, qa_llm)
328+
329+
result = {
330+
"session_id": session_id,
331+
"message": graph_response["response"],
332+
"info": {
333+
"model": model_version,
334+
'cypher_query':graph_response["cypher_query"],
335+
"context" : graph_response["context"],
336+
"mode" : mode,
337+
"response_time": 0
338+
},
276339
"user": "chatbot"
277-
}
340+
}
278341
return result
342+
elif mode == "vector":
343+
retrieval_query = VECTOR_SEARCH_QUERY
279344
else:
280345
retrieval_query = VECTOR_GRAPH_SEARCH_QUERY
281346

282-
llm, doc_retriever, history, model_version = setup_chat(model, graph, session_id, document_names,retrieval_query)
283-
messages = history.messages
284-
user_question = HumanMessage(content=question)
285-
messages.append(user_question)
347+
llm, doc_retriever, model_version = setup_chat(model, graph, session_id, document_names,retrieval_query)
286348

287349
docs = retrieve_documents(doc_retriever, messages)
288350

frontend/example.env

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ TIME_PER_CHUNK=4
77
TIME_PER_PAGE=50
88
CHUNK_SIZE=5242880
99
LARGE_FILE_SIZE=5242880
10-
GOOGLE_CLIENT_ID=""
10+
GOOGLE_CLIENT_ID=""
11+
CHAT_MODES=""

frontend/src/components/ChatBot/ChatInfoModal.tsx

Lines changed: 125 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
1-
import { Box, Typography, TextLink, Flex, Tabs, LoadingSpinner } from '@neo4j-ndl/react';
2-
import { DocumentTextIconOutline } from '@neo4j-ndl/react/icons';
1+
import {
2+
Box,
3+
Typography,
4+
TextLink,
5+
Flex,
6+
Tabs,
7+
LoadingSpinner,
8+
CypherCodeBlock,
9+
CypherCodeBlockProps,
10+
useCopyToClipboard,
11+
} from '@neo4j-ndl/react';
12+
import { DocumentDuplicateIconOutline, DocumentTextIconOutline } from '@neo4j-ndl/react/icons';
313
import '../../styling/info.css';
414
import Neo4jRetrievalLogo from '../../assets/images/Neo4jRetrievalLogo.png';
515
import wikipedialogo from '../../assets/images/wikipedia.svg';
616
import youtubelogo from '../../assets/images/youtube.svg';
717
import gcslogo from '../../assets/images/gcs.webp';
818
import s3logo from '../../assets/images/s3logo.png';
919
import { Chunk, Entity, GroupedEntity, UserCredentials, chatInfoMessage } from '../../types';
10-
import { useEffect, useMemo, useState } from 'react';
20+
import { useContext, useEffect, useMemo, useState } from 'react';
1121
import HoverableLink from '../UI/HoverableLink';
1222
import GraphViewButton from '../Graph/GraphViewButton';
1323
import { chunkEntitiesAPI } from '../../services/ChunkEntitiesInfo';
@@ -17,44 +27,88 @@ import { calcWordColor } from '@neo4j-devtools/word-color';
1727
import ReactMarkdown from 'react-markdown';
1828
import { GlobeAltIconOutline } from '@neo4j-ndl/react/icons';
1929
import { youtubeLinkValidation } from '../../utils/Utils';
30+
import { ThemeWrapperContext } from '../../context/ThemeWrapper';
31+
import { ClipboardDocumentCheckIconOutline } from '@neo4j-ndl/react/icons';
2032

21-
const ChatInfoModal: React.FC<chatInfoMessage> = ({ sources, model, total_tokens, response_time, chunk_ids, mode }) => {
22-
const [activeTab, setActiveTab] = useState<number>(3);
33+
const ChatInfoModal: React.FC<chatInfoMessage> = ({
34+
sources,
35+
model,
36+
total_tokens,
37+
response_time,
38+
chunk_ids,
39+
mode,
40+
cypher_query,
41+
graphonly_entities,
42+
}) => {
43+
const [activeTab, setActiveTab] = useState<number>(mode === 'graph' ? 4 : 3);
2344
const [infoEntities, setInfoEntities] = useState<Entity[]>([]);
2445
const [loading, setLoading] = useState<boolean>(false);
2546
const { userCredentials } = useCredentials();
2647
const [nodes, setNodes] = useState<Node[]>([]);
2748
const [relationships, setRelationships] = useState<Relationship[]>([]);
2849
const [chunks, setChunks] = useState<Chunk[]>([]);
50+
const themeUtils = useContext(ThemeWrapperContext);
51+
const [, copy] = useCopyToClipboard();
52+
const [copiedText, setcopiedText] = useState<boolean>(false);
53+
2954
const parseEntity = (entity: Entity) => {
3055
const { labels, properties } = entity;
3156
const label = labels[0];
3257
const text = properties.id;
3358
return { label, text };
3459
};
60+
const actions: CypherCodeBlockProps['actions'] = useMemo(
61+
() => [
62+
{
63+
title: 'copy',
64+
'aria-label': 'copy',
65+
children: (
66+
<>
67+
{copiedText ? (
68+
<ClipboardDocumentCheckIconOutline className='n-size-token-7' />
69+
) : (
70+
<DocumentDuplicateIconOutline className='text-palette-neutral-text-icon' />
71+
)}
72+
</>
73+
),
74+
onClick: () => {
75+
void copy(cypher_query as string);
76+
setcopiedText(true);
77+
},
78+
},
79+
],
80+
[copiedText, cypher_query]
81+
);
3582
useEffect(() => {
36-
setLoading(true);
37-
chunkEntitiesAPI(userCredentials as UserCredentials, chunk_ids.map((c) => c.id).join(','))
38-
.then((response) => {
39-
setInfoEntities(response.data.data.nodes);
40-
setNodes(response.data.data.nodes);
41-
setRelationships(response.data.data.relationships);
42-
const chunks = response.data.data.chunk_data.map((chunk: any) => {
43-
const chunkScore = chunk_ids.find((chunkdetail) => chunkdetail.id === chunk.id);
44-
return {
45-
...chunk,
46-
score: chunkScore?.score,
47-
};
83+
if (mode != 'graph') {
84+
setLoading(true);
85+
chunkEntitiesAPI(userCredentials as UserCredentials, chunk_ids.map((c) => c.id).join(','))
86+
.then((response) => {
87+
console.log({ response });
88+
setInfoEntities(response.data.data.nodes);
89+
setNodes(response.data.data.nodes);
90+
setRelationships(response.data.data.relationships);
91+
const chunks = response.data.data.chunk_data.map((chunk: any) => {
92+
const chunkScore = chunk_ids.find((chunkdetail) => chunkdetail.id === chunk.id);
93+
return {
94+
...chunk,
95+
score: chunkScore?.score,
96+
};
97+
});
98+
const sortedchunks = chunks.sort((a: any, b: any) => b.score - a.score);
99+
setChunks(sortedchunks);
100+
setLoading(false);
101+
})
102+
.catch((error) => {
103+
console.error('Error fetching entities:', error);
104+
setLoading(false);
48105
});
49-
const sortedchunks = chunks.sort((a: any, b: any) => b.score - a.score);
50-
setChunks(sortedchunks);
51-
setLoading(false);
52-
})
53-
.catch((error) => {
54-
console.error('Error fetching entities:', error);
55-
setLoading(false);
56-
});
57-
}, [chunk_ids]);
106+
}
107+
108+
() => {
109+
setcopiedText(false);
110+
};
111+
}, [chunk_ids, mode]);
58112
const groupedEntities = useMemo<{ [key: string]: GroupedEntity }>(() => {
59113
return infoEntities.reduce((acc, entity) => {
60114
const { label, text } = parseEntity(entity);
@@ -107,9 +161,10 @@ const ChatInfoModal: React.FC<chatInfoMessage> = ({ sources, model, total_tokens
107161
</Box>
108162
</Box>
109163
<Tabs size='large' fill='underline' onChange={onChangeTabs} value={activeTab}>
110-
<Tabs.Tab tabId={3}>Sources used</Tabs.Tab>
111-
{mode === 'graph+vector' && <Tabs.Tab tabId={4}>Top Entities used</Tabs.Tab>}
112-
<Tabs.Tab tabId={5}>Chunks</Tabs.Tab>
164+
{mode != 'graph' && <Tabs.Tab tabId={3}>Sources used</Tabs.Tab>}
165+
{(mode === 'graph+vector' || mode === 'graph') && <Tabs.Tab tabId={4}>Top Entities used</Tabs.Tab>}
166+
{mode === 'graph' && cypher_query?.trim().length && <Tabs.Tab tabId={6}>Generated Cypher Query</Tabs.Tab>}
167+
{mode != 'graph' && <Tabs.Tab tabId={5}>Chunks</Tabs.Tab>}
113168
</Tabs>
114169
<Flex className='p-4'>
115170
<Tabs.TabPanel className='n-flex n-flex-col n-gap-token-4 n-p-token-6' value={activeTab} tabId={3}>
@@ -225,28 +280,39 @@ const ChatInfoModal: React.FC<chatInfoMessage> = ({ sources, model, total_tokens
225280
<Box className='flex justify-center items-center'>
226281
<LoadingSpinner size='small' />
227282
</Box>
228-
) : Object.keys(groupedEntities).length > 0 ? (
283+
) : Object.keys(groupedEntities).length > 0 || Object.keys(graphonly_entities).length > 0 ? (
229284
<ul className='list-none p-4 max-h-80 overflow-auto'>
230-
{sortedLabels.map((label, index) => (
231-
<li
232-
key={index}
233-
className='flex items-center mb-2 text-ellipsis whitespace-nowrap max-w-[100%)] overflow-hidden'
234-
>
235-
<div
236-
key={index}
237-
style={{ backgroundColor: `${groupedEntities[label].color}` }}
238-
className='legend mr-2'
239-
>
240-
{label} ({labelCounts[label]})
241-
</div>
242-
<Typography
243-
className='entity-text text-ellipsis whitespace-nowrap max-w-[calc(100%-120px)] overflow-hidden'
244-
variant='body-medium'
245-
>
246-
{Array.from(groupedEntities[label].texts).slice(0, 3).join(', ')}
247-
</Typography>
248-
</li>
249-
))}
285+
{mode == 'graph'
286+
? graphonly_entities.map((label, index) => (
287+
<li
288+
key={index}
289+
className='flex items-center mb-2 text-ellipsis whitespace-nowrap max-w-[100%)] overflow-hidden'
290+
>
291+
<div style={{ backgroundColor: calcWordColor(Object.keys(label)[0]) }} className='legend mr-2'>
292+
{Object.keys(label)[0]}
293+
</div>
294+
</li>
295+
))
296+
: sortedLabels.map((label, index) => (
297+
<li
298+
key={index}
299+
className='flex items-center mb-2 text-ellipsis whitespace-nowrap max-w-[100%)] overflow-hidden'
300+
>
301+
<div
302+
key={index}
303+
style={{ backgroundColor: `${groupedEntities[label].color}` }}
304+
className='legend mr-2'
305+
>
306+
{label} ({labelCounts[label]})
307+
</div>
308+
<Typography
309+
className='entity-text text-ellipsis whitespace-nowrap max-w-[calc(100%-120px)] overflow-hidden'
310+
variant='body-medium'
311+
>
312+
{Array.from(groupedEntities[label].texts).slice(0, 3).join(', ')}
313+
</Typography>
314+
</li>
315+
))}
250316
</ul>
251317
) : (
252318
<span className='h6 text-center'>No Entities Found</span>
@@ -340,6 +406,15 @@ const ChatInfoModal: React.FC<chatInfoMessage> = ({ sources, model, total_tokens
340406
<span className='h6 text-center'>No Chunks Found</span>
341407
)}
342408
</Tabs.TabPanel>
409+
<Tabs.TabPanel value={activeTab} tabId={6}>
410+
<CypherCodeBlock
411+
code={cypher_query as string}
412+
actions={actions}
413+
headerTitle=''
414+
theme={themeUtils.colorMode}
415+
className='min-h-40'
416+
/>
417+
</Tabs.TabPanel>
343418
</Flex>
344419
{activeTab == 4 && nodes.length && relationships.length ? (
345420
<Box className='button-container flex mt-2 justify-center'>

0 commit comments

Comments (0)