Skip to content

Commit 7c07230

Browse files
Merge pull request #220 from neo4j-labs/configurable_chatbot
Configurable chatbot
2 parents e692f50 + 7a4c775 commit 7c07230

File tree

7 files changed

+207
-44
lines changed

7 files changed

+207
-44
lines changed

backend/score.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,11 @@ async def update_similarity_graph(uri=Form(None), userName=Form(None), password=
213213
return create_api_response(job_status,message=message,error=error_message)
214214

215215
@app.post("/chat_bot")
216-
async def chat_bot(uri=Form(None), userName=Form(None), password=Form(None), question=Form(None), session_id=Form(None)):
216+
async def chat_bot(uri=Form(None),model=Form(None),userName=Form(None), password=Form(None), question=Form(None), session_id=Form(None)):
217217
try:
218-
result = await asyncio.to_thread(QA_RAG,uri=uri,userName=userName,password=password,question=question,session_id=session_id)
218+
# model=Form(None),
219+
# model = "Gemini Pro"
220+
result = await asyncio.to_thread(QA_RAG,uri=uri,model=model,userName=userName,password=password,question=question,session_id=session_id)
219221
return create_api_response('Success',data=result)
220222
except Exception as e:
221223
job_status = "Failed"

backend/src/QA_integration.py

Lines changed: 137 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,79 @@
77
from langchain_openai import ChatOpenAI
88
from langchain_openai import OpenAIEmbeddings
99
from langchain_google_vertexai import VertexAIEmbeddings
10+
from langchain_google_vertexai import ChatVertexAI
11+
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
1012
import logging
1113
from langchain_community.chat_message_histories import Neo4jChatMessageHistory
14+
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
15+
from src.shared.common_fn import load_embedding_model
16+
import re
17+
1218
load_dotenv()
1319

1420
openai_api_key = os.environ.get('OPENAI_API_KEY')
15-
model_version='gpt-4-0125-preview'
21+
22+
23+
# def get_embedding_function(embedding_model_name: str):
24+
# if embedding_model_name == "openai":
25+
# embedding_function = OpenAIEmbeddings()
26+
# dimension = 1536
27+
# logging.info(f"Embedding: Using OpenAI Embeddings , Dimension:{dimension}")
28+
# elif embedding_model_name == "vertexai":
29+
# embedding_function = VertexAIEmbeddings(
30+
# model_name="textembedding-gecko@003"
31+
# )
32+
# dimension = 768
33+
# logging.info(f"Embedding: Using Vertex AI Embeddings , Dimension:{dimension}")
34+
# else:
35+
# embedding_function = SentenceTransformerEmbeddings(
36+
# model_name="all-MiniLM-L6-v2"#, cache_folder="/embedding_model"
37+
# )
38+
# dimension = 384
39+
# logging.info(f"Embedding: Using SentenceTransformer , Dimension:{dimension}")
40+
# return embedding_function
41+
42+
def get_llm(model: str):
    """Return a chat LLM instance for the model name chosen in the UI.

    Args:
        model: Display name selected by the user — "OpenAI GPT 3.5",
            "Gemini Pro", or "Gemini 1.5 Pro". Any other value
            (e.g. "OpenAI GPT 4" or "Diffbot") falls back to GPT-4.

    Returns:
        A configured ChatOpenAI or ChatVertexAI instance with
        temperature 0 (deterministic answers for QA).
    """
    # Both Gemini branches previously repeated this identical dict verbatim;
    # define it once. All safety filters are disabled so RAG answers built
    # from user documents are not blocked by Vertex AI content filtering.
    gemini_safety_settings = {
        HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
    }
    if model == "OpenAI GPT 3.5":
        model_version = "gpt-3.5-turbo-16k"
        logging.info(f"Chat Model: GPT 3.5, Model Version : {model_version}")
        llm = ChatOpenAI(model=model_version, temperature=0)
    elif model == "Gemini Pro":
        model_version = 'gemini-1.0-pro-001'
        logging.info(f"Chat Model: Gemini , Model Version : {model_version}")
        llm = ChatVertexAI(model_name=model_version,
                           # convert_system_message_to_human: Gemini has no
                           # system role, so system prompts are folded into
                           # the human turn.
                           convert_system_message_to_human=True,
                           temperature=0,
                           safety_settings=gemini_safety_settings)
    elif model == "Gemini 1.5 Pro":
        model_version = "gemini-1.5-pro-preview-0409"
        logging.info(f"Chat Model: Gemini 1.5 , Model Version : {model_version}")
        llm = ChatVertexAI(model_name=model_version,
                           convert_system_message_to_human=True,
                           temperature=0,
                           safety_settings=gemini_safety_settings)
    else:
        # Fallback for model == "OpenAI GPT 4" or model == "Diffbot".
        model_version = "gpt-4-0125-preview"
        logging.info(f"Chat Model: GPT 4, Model Version : {model_version}")
        llm = ChatOpenAI(model=model_version, temperature=0)
    return llm
1683

1784
def vector_embed_results(qa,question):
1885
vector_res={}
@@ -92,29 +159,52 @@ def get_chat_history(llm,uri,userName,password,session_id):
92159
error_message = str(e)
93160
logging.exception(f'Exception in retrieving chat history:{error_message}')
94161
# raise Exception(error_message)
95-
return ''
162+
return ''
163+
164+
def extract_and_remove_source(message):
    """Split a chatbot reply into its visible text and its cited sources.

    The LLM is prompted to append citations as ``[Source: s1,s2]``; this
    pulls the source list out of the first such tag and strips every tag
    occurrence from the message shown to the user.

    Args:
        message: Raw LLM reply, possibly containing a source tag.

    Returns:
        dict with keys ``"message"`` (tag-free, stripped text) and
        ``"sources"`` (list of source strings; empty when no tag found).
    """
    source_pattern = r'\[Source: ([^\]]+)\]'
    found = re.search(source_pattern, message)
    if not found:
        # No citation tag at all — pass the message through unchanged.
        return {"message": message, "sources": []}
    # Sources come from the FIRST tag only, but every tag is removed
    # from the displayed text.
    cleaned_message = re.sub(source_pattern, '', message).strip()
    parsed_sources = [
        item.strip().strip("'") for item in found.group(1).split(',')
    ]
    return {"message": cleaned_message, "sources": parsed_sources}
96181

97-
def QA_RAG(uri,userName,password,question,session_id):
182+
def QA_RAG(uri,model,userName,password,question,session_id):
98183
try:
99184
retrieval_query="""
100185
MATCH (node)-[:PART_OF]->(d:Document)
101186
WITH d, apoc.text.join(collect(node.text),"\n----\n") as text, avg(score) as score
102187
RETURN text, score, {source: COALESCE(CASE WHEN d.url CONTAINS "None" THEN d.fileName ELSE d.url END, d.fileName)} as metadata
103188
"""
104189
embedding_model = os.getenv('EMBEDDING_MODEL')
190+
embedding_function, _ = load_embedding_model(embedding_model)
105191
neo_db=Neo4jVector.from_existing_index(
106-
embedding = VertexAIEmbeddings(model_name=embedding_model),
192+
embedding = embedding_function,
107193
url=uri,
108194
username=userName,
109195
password=password,
110196
database="neo4j",
111197
index_name="vector",
112198
retrieval_query=retrieval_query,
113199
)
114-
llm = ChatOpenAI(model= model_version, temperature=0)
200+
# model = "Gemini Pro"
201+
llm = get_llm(model = model)
115202

116203
qa = RetrievalQA.from_chain_type(
117-
llm=llm, chain_type="stuff", retriever=neo_db.as_retriever(search_kwargs={'k': 3,"score_threshold": 0.5}), return_source_documents=True
204+
llm=llm,
205+
chain_type="stuff",
206+
retriever=neo_db.as_retriever(search_kwargs={'k': 3,"score_threshold": 0.5}),
207+
return_source_documents=True
118208
)
119209

120210
vector_res=vector_embed_results(qa,question)
@@ -133,32 +223,58 @@ def QA_RAG(uri,userName,password,question,session_id):
133223

134224
chat_summary=get_chat_history(llm,uri,userName,password,session_id)
135225

136-
final_prompt = f"""You are a helpful question-answering agent. Your task is to analyze
137-
and synthesize information from two sources: the top result from a similarity search
138-
(unstructured information) and relevant data from a graph database (structured information).
139-
If structured information fails to find an answer then use the answer from unstructured information
140-
and vice versa. I only want a straightforward answer without mentioning from which source you got the answer. You are also receiving
141-
a chat history of the earlier conversation. You should be able to understand the context from the chat history and answer the question.
142-
Given the user's query: {question}, provide a meaningful and efficient answer based
143-
on the insights derived from the following data:
144-
chat_summary:{chat_summary}
145-
Structured information: .
146-
Unstructured information: {vector_res.get('result','')}.
147226

227+
# final_prompt = f"""You are a helpful question-answering agent. Your task is to analyze
228+
# and synthesize information from two sources: the top result from a similarity search
229+
# (unstructured information) and relevant data from a graph database (structured information).
230+
# If structured information fails to find an answer then use the answer from unstructured information
231+
# and vice versa. I only want a straightforward answer without mentioning from which source you got the answer. You are also receiving
232+
# a chat history of the earlier conversation. You should be able to understand the context from the chat history and answer the question.
233+
# Given the user's query: {question}, provide a meaningful and efficient answer based
234+
# on the insights derived from the following data:
235+
# chat_summary:{chat_summary}
236+
# Structured information: .
237+
# Unstructured information: {vector_res.get('result','')}.
238+
# """
239+
240+
final_prompt = f"""
241+
You are an AI-powered question-answering agent tasked with providing accurate and direct responses to user queries. Utilize information from the chat history, current user input, and relevant unstructured data effectively.
242+
243+
Response Requirements:
244+
- Deliver concise and direct answers to the user's query without headers unless requested.
245+
- Acknowledge and utilize relevant previous interactions based on the chat history summary.
246+
- Respond to initial greetings appropriately, but avoid including a greeting in subsequent responses unless the chat is restarted or significantly paused.
247+
- Clearly state if an answer is unknown; avoid speculating.
248+
249+
Instructions:
250+
- Prioritize directly answering the User Input: {question}.
251+
- Use the Chat History Summary: {chat_summary} to provide context-aware responses.
252+
- Refer to Additional Unstructured Information: {vector_res.get('result', '')} only if it directly relates to the query.
253+
- Cite sources clearly when using unstructured data in your response [Sources: {vector_res.get('source', '')}]. The Source must be printed only at the last in the format [Source: source1,source2]
254+
Ensure that answers are straightforward and context-aware, focusing on being relevant and concise.
148255
"""
149256

150257
print(final_prompt)
258+
llm = get_llm(model = model)
151259
response = llm.predict(final_prompt)
260+
# print(response)
261+
152262
ai_message=response
153263
user_message=question
154264
save_chat_history(uri,userName,password,session_id,user_message,ai_message)
155265

156-
res={"session_id":session_id,"message":response,"user":"chatbot"}
266+
reponse = extract_and_remove_source(response)
267+
message = reponse["message"]
268+
sources = reponse["sources"]
269+
# print(extract_and_remove_source(response))
270+
print(response)
271+
res={"session_id":session_id,"message":message,"sources":sources,"user":"chatbot"}
157272
return res
158273
except Exception as e:
159274
error_message = str(e)
160275
logging.exception(f'Exception in in QA component:{error_message}')
161-
# raise Exception(error_message)
162-
return {"session_id":session_id,"message":"Something went wrong","user":"chatbot"}
276+
message = "Something went wrong"
277+
sources = []
278+
# raise Exception(error_message)
279+
return {"session_id":session_id,"message":message,"sources":sources,"user":"chatbot"}
163280

164-

backend/src/shared/common_fn.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import logging
22
from src.document_sources.youtube import create_youtube_url
3+
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
4+
from langchain_google_vertexai import VertexAIEmbeddings
5+
from langchain_openai import OpenAIEmbeddings
36
from langchain.docstore.document import Document
47
import re
58
import os
@@ -56,4 +59,24 @@ def get_chunk_and_graphDocument(graph_document_list, chunkId_chunkDoc_list):
5659
lst_chunk_chunkId_document.append({'graph_doc':graph_document,'chunk_id':chunk_id})
5760

5861
return lst_chunk_chunkId_document
62+
63+
64+
def load_embedding_model(embedding_model_name: str):
    """Instantiate the embedding backend selected by name.

    Args:
        embedding_model_name: "openai", "vertexai", or anything else
            (any other value falls back to a local SentenceTransformer).

    Returns:
        Tuple of (embeddings instance, embedding vector dimension).
    """
    if embedding_model_name == "openai":
        chosen = OpenAIEmbeddings()
        dim = 1536
        logging.info(f"Embedding: Using OpenAI Embeddings , Dimension:{dim}")
    elif embedding_model_name == "vertexai":
        chosen = VertexAIEmbeddings(model="textembedding-gecko@003")
        dim = 768
        logging.info(f"Embedding: Using Vertex AI Embeddings , Dimension:{dim}")
    else:
        # Local fallback; no API key required.
        chosen = SentenceTransformerEmbeddings(
            model_name="all-MiniLM-L6-v2"  # , cache_folder="/embedding_model"
        )
        dim = 384
        logging.info(f"Embedding: Using SentenceTransformer , Dimension:{dim}")
    return chosen, dim
5982

frontend/src/assets/ChatbotMessages.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
"id": 2,
1111
"message": " Welcome to the Neo4j Knowledge Graph Chat. You can ask questions related to documents which have been completely processed.",
1212
"user": "chatbot",
13-
"datetime": "01/01/2024 00:00:00"
13+
"datetime": "01/01/2024 00:00:00",
14+
"sources":["https://neo4j.com/"]
1415
},
1516
{
1617
"id": 3,

frontend/src/components/Chatbot.tsx

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
/* eslint-disable no-confusing-arrow */
22
import { useEffect, useRef, useState } from 'react';
3-
import { Button, Widget, Typography, Avatar, TextInput } from '@neo4j-ndl/react';
3+
import { Button, Widget, Typography, Avatar, TextInput, TextLink } from '@neo4j-ndl/react';
44
import ChatBotUserAvatar from '../assets/images/chatbot-user.png';
55
import ChatBotAvatar from '../assets/images/chatbot-ai.png';
66
import { ChatbotProps, UserCredentials } from '../types';
77
import { useCredentials } from '../context/UserCredentials';
88
import chatBotAPI from '../services/QnaAPI';
99
import { v4 as uuidv4 } from 'uuid';
10+
import { useFileContext } from '../context/UsersFiles';
1011

1112
export default function Chatbot(props: ChatbotProps) {
1213
const { messages: listMessages, setMessages: setListMessages } = props;
1314
const [inputMessage, setInputMessage] = useState('');
1415
const formattedTextStyle = { color: 'rgb(var(--theme-palette-discovery-bg-strong))' };
1516
const [loading, setLoading] = useState<boolean>(false);
1617
const { userCredentials } = useCredentials();
18+
const { model } = useFileContext();
1719
const messagesEndRef = useRef<HTMLDivElement>(null);
1820
const [sessionId, setSessionId] = useState<string>(sessionStorage.getItem('session_id') ?? '');
1921

@@ -29,17 +31,24 @@ export default function Chatbot(props: ChatbotProps) {
2931
}
3032
}, []);
3133

32-
const simulateTypingEffect = (responseText: string, index = 0) => {
33-
if (index < responseText.length) {
34+
const simulateTypingEffect = (response: { reply: string; sources?: [string] }, index = 0) => {
35+
if (index < response.reply.length) {
3436
const nextIndex = index + 1;
35-
const currentTypedText = responseText.substring(0, nextIndex);
37+
const currentTypedText = response.reply.substring(0, nextIndex);
3638
if (index === 0) {
3739
const date = new Date();
3840
const datetime = `${date.toLocaleDateString()} ${date.toLocaleTimeString()}`;
39-
if (responseText.length <= 1) {
41+
if (response.reply.length <= 1) {
4042
setListMessages((msgs) => [
4143
...msgs,
42-
{ id: Date.now(), user: 'chatbot', message: currentTypedText, datetime: datetime, isTyping: true },
44+
{
45+
id: Date.now(),
46+
user: 'chatbot',
47+
message: currentTypedText,
48+
datetime: datetime,
49+
isTyping: true,
50+
sources: response?.sources,
51+
},
4352
]);
4453
} else {
4554
setListMessages((msgs) => {
@@ -49,6 +58,7 @@ export default function Chatbot(props: ChatbotProps) {
4958
lastmsg.message = currentTypedText;
5059
lastmsg.datetime = datetime;
5160
lastmsg.isTyping = true;
61+
lastmsg.sources = response?.sources;
5262
return msgs.map((msg, index) => {
5363
if (index === msgs.length - 1) {
5464
return lastmsg;
@@ -60,7 +70,7 @@ export default function Chatbot(props: ChatbotProps) {
6070
} else {
6171
setListMessages((msgs) => msgs.map((msg) => (msg.isTyping ? { ...msg, message: currentTypedText } : msg)));
6272
}
63-
setTimeout(() => simulateTypingEffect(responseText, nextIndex), 20);
73+
setTimeout(() => simulateTypingEffect(response, nextIndex), 20);
6474
} else {
6575
setListMessages((msgs) => msgs.map((msg) => (msg.isTyping ? { ...msg, isTyping: false } : msg)));
6676
}
@@ -79,15 +89,15 @@ export default function Chatbot(props: ChatbotProps) {
7989
try {
8090
setLoading(true);
8191
setInputMessage('');
82-
simulateTypingEffect(' ');
83-
const chatresponse = await chatBotAPI(userCredentials as UserCredentials, inputMessage, sessionId);
84-
chatbotReply = chatresponse?.data?.message;
85-
simulateTypingEffect(chatbotReply);
92+
simulateTypingEffect({ reply: ' ' });
93+
const chatresponse = await chatBotAPI(userCredentials as UserCredentials, inputMessage, sessionId, model);
94+
chatbotReply = chatresponse?.data?.data?.message;
95+
simulateTypingEffect({ reply: chatbotReply, sources: chatresponse?.data?.data?.sources });
8696
setLoading(false);
8797
} catch (error) {
8898
chatbotReply = "Oops! It seems we couldn't retrieve the answer. Please try again later";
8999
setInputMessage('');
90-
simulateTypingEffect(chatbotReply);
100+
simulateTypingEffect({ reply: chatbotReply });
91101
setLoading(false);
92102
}
93103
};
@@ -160,6 +170,21 @@ export default function Chatbot(props: ChatbotProps) {
160170
</div>
161171
<div className='text-right align-bottom pt-3'>
162172
<Typography variant='body-small'>{chat.datetime}</Typography>
173+
{chat?.sources?.length ? (
174+
<div className={`flex ${chat.sources?.length > 1 ? 'flex-col' : 'flex-row justify-end'} gap-1`}>
175+
{chat.sources.map((link) => (
176+
<div className='text-right'>
177+
{link.startsWith('http') || link.startsWith('https') ? (
178+
<TextLink href={link} externalLink={true}>
179+
Source
180+
</TextLink>
181+
) : (
182+
<Typography variant='body-small'>{link}</Typography>
183+
)}
184+
</div>
185+
))}
186+
</div>
187+
) : null}
163188
</div>
164189
</Widget>
165190
</div>

0 commit comments

Comments
 (0)