Skip to content

Commit 0b5bb74

Browse files
authored
Merge pull request #34 from DefangLabs/linda-test-chatbot
Add tests, refactoring, and fallback response to chatbot
2 parents 6cde831 + e1947d4 commit 0b5bb74

File tree

7 files changed

+845
-52
lines changed

7 files changed

+845
-52
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
.env
12
__pycache__
23
sentence-transformers
3-
myenv/
44
.tmp/*
55
!.tmp/prebuild.sh

app/app.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def validate_pow(nonce, data, difficulty):
2828

2929
@app.route('/', methods=['GET', 'POST'])
3030
def index():
31-
return render_template('index.html')
31+
return render_template('index.html', debug=os.getenv('DEBUG'))
3232

3333
@app.route('/ask', methods=['POST'])
3434
def ask():
@@ -98,5 +98,15 @@ def trigger_rebuild():
9898
print(f"Error in /trigger-rebuild endpoint: {e}")
9999
return jsonify({"error": "Internal Server Error"}), 500
100100

101+
if os.getenv('DEBUG') == '1':
102+
@app.route('/ask/debug', methods=['POST'])
103+
def debug_context():
104+
data = request.get_json()
105+
query = data.get('query', '')
106+
if not query:
107+
return jsonify({"error": "Query is required"}), 400
108+
context = rag_system.get_context(query)
109+
return jsonify({"context": context})
110+
101111
if __name__ == '__main__':
102112
app.run(host='0.0.0.0', port=5050)

app/rag_system.py

Lines changed: 84 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,38 +26,92 @@ def embed_knowledge_base(self):
2626

2727
def normalize_query(self, query):
2828
return query.lower().strip()
29-
30-
def retrieve(self, query, similarity_threshold=0.7, high_match_threshold=0.8, max_docs=5):
29+
30+
def get_query_embedding(self, query, use_cpu=False):
3131
normalized_query = self.normalize_query(query)
3232
query_embedding = self.model.encode([normalized_query], convert_to_tensor=True)
33-
similarities = cosine_similarity(query_embedding, self.doc_embeddings)[0]
34-
relevance_scores = []
35-
36-
for i, doc in enumerate(self.knowledge_base):
33+
if use_cpu:
34+
query_embedding = query_embedding.cpu()
35+
return query_embedding
36+
37+
def get_doc_embeddings(self, use_cpu=False):
38+
if use_cpu:
39+
return self.doc_embeddings.cpu()
40+
return self.doc_embeddings
41+
42+
def compute_document_scores(self, query_embedding, doc_embeddings, high_match_threshold):
43+
text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
44+
about_similarities = []
45+
for doc in self.knowledge_base:
3746
about_similarity = cosine_similarity(query_embedding, self.model.encode([doc["about"]]))[0][0]
38-
text_similarity = similarities[i]
39-
40-
combined_score = (0.3 * about_similarity) + (0.7 * text_similarity)
41-
if about_similarity >= high_match_threshold or text_similarity >= high_match_threshold:
42-
combined_score = max(about_similarity, text_similarity)
43-
44-
relevance_scores.append((i, combined_score))
47+
about_similarities.append(about_similarity)
48+
49+
relevance_scores = self.compute_relevance_scores(text_similarities, about_similarities, high_match_threshold)
50+
51+
result = [
52+
{
53+
"index": i,
54+
"about": doc["about"],
55+
"text": doc["text"],
56+
"text_similarity": text_similarities[i],
57+
"about_similarity": about_similarities[i],
58+
"relevance_score": relevance_scores[i]
59+
}
60+
for i, doc in enumerate(self.knowledge_base)
61+
]
4562

46-
sorted_indices = sorted(relevance_scores, key=lambda x: x[1], reverse=True)
47-
top_indices = [i for i, score in sorted_indices[:max_docs] if score >= similarity_threshold]
63+
return result
4864

49-
retrieved_docs = [f'{self.knowledge_base[i]["about"]}. {self.knowledge_base[i]["text"]}' for i in top_indices]
65+
def retrieve(self, query, similarity_threshold=0.7, high_match_threshold=0.8, max_docs=5, use_cpu=False):
66+
# Note: Set use_cpu=True to run on CPU, which is useful for testing or environments without a GPU.
67+
# Set use_cpu=False to leverage GPU for better performance in production.
68+
69+
query_embedding = self.get_query_embedding(query, use_cpu)
70+
doc_embeddings = self.get_doc_embeddings(use_cpu)
5071

72+
doc_scores = self.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold)
73+
retrieved_docs = self.get_top_docs(doc_scores, similarity_threshold, max_docs)
74+
5175
if not retrieved_docs:
52-
max_index = np.argmax(similarities)
53-
retrieved_docs.append(f'{self.knowledge_base[max_index]["about"]}. {self.knowledge_base[max_index]["text"]}')
54-
55-
return "\n\n".join(retrieved_docs)
76+
retrieved_docs = self.get_fallback_doc()
77+
return retrieved_docs
78+
5679

80+
def compute_relevance_scores(self, text_similarities, about_similarities, high_match_threshold):
81+
relevance_scores = []
82+
for i, _ in enumerate(self.knowledge_base):
83+
about_similarity = about_similarities[i]
84+
text_similarity = text_similarities[i]
85+
# If either about or text similarity is above the high match threshold, prioritize it
86+
if about_similarity >= high_match_threshold or text_similarity >= high_match_threshold:
87+
combined_score = max(about_similarity, text_similarity)
88+
else:
89+
combined_score = (0.3 * about_similarity) + (0.7 * text_similarity)
90+
relevance_scores.append(combined_score)
91+
92+
return relevance_scores
93+
94+
def get_top_docs(self, doc_scores, similarity_threshold, max_docs):
95+
sorted_docs = sorted(doc_scores, key=lambda x: x["relevance_score"], reverse=True)
96+
# Filter and keep up to max_docs with relevance scores above the similarity threshold
97+
top_docs = [score for score in sorted_docs[:max_docs] if score["relevance_score"] >= similarity_threshold]
98+
return top_docs
99+
100+
def get_fallback_doc(self):
101+
return [
102+
{
103+
"about": "No Relevant Information Found",
104+
"text": (
105+
"I'm sorry, I couldn't find any relevant information for your query. "
106+
"Please try rephrasing your question or ask about a different topic. "
107+
"For further assistance, you can visit our official website or reach out to our support team."
108+
)
109+
}
110+
]
111+
57112
def answer_query_stream(self, query):
58113
try:
59-
normalized_query = self.normalize_query(query)
60-
context = self.retrieve(normalized_query)
114+
context = self.get_context(query)
61115

62116
self.conversation_history.append({"role": "user", "content": query})
63117

@@ -117,5 +171,13 @@ def rebuild_embeddings(self):
117171
self.doc_embeddings = self.embed_knowledge_base() # Rebuild the embeddings
118172
print("Embeddings have been rebuilt.")
119173

174+
def get_context(self, query):
175+
normalized_query = self.normalize_query(query)
176+
retrieved_docs = self.retrieve(normalized_query)
177+
retrieved_text = []
178+
for doc in retrieved_docs:
179+
retrieved_text.append(f'{doc["about"]}. {doc["text"]}')
180+
return "\n\n".join(retrieved_text)
181+
120182
# Instantiate the RAGSystem
121183
rag_system = RAGSystem()

app/templates/index.html

Lines changed: 66 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
<head>
44
<meta charset="UTF-8">
55
<meta name="viewport" content="width=device-width, initial-scale=1.0">
6+
<meta name="debug" content="{{ debug }}">
67
<title>Ask Defang</title>
78
<script src="https://cdnjs.cloudflare.com/ajax/libs/marked/4.0.2/marked.min.js"></script>
89
<link rel="icon" href="{{ url_for('static', filename='images/favicon.ico') }}" type="image/x-icon">
@@ -259,7 +260,6 @@ <h2>Ask Defang</h2>
259260
const responseContainer = document.createElement('p');
260261
responseContainer.innerHTML = `<strong>Assistant:</strong> <span class="assistant-response"></span>`;
261262
chatBox.appendChild(responseContainer);
262-
const assistantResponse = responseContainer.querySelector('.assistant-response');
263263

264264
chatBox.scrollTop = chatBox.scrollHeight;
265265

@@ -268,44 +268,84 @@ <h2>Ask Defang</h2>
268268
loadingSpinner.style.display = 'inline-block';
269269
sendButton.disabled = true;
270270

271+
const debug = document.querySelector('meta[name=debug]').content;
272+
if (debug == '1') {
273+
askDebug(query)
274+
} else {
275+
ask(query, responseContainer)
276+
}
277+
}
278+
279+
function ask(query, responseContainer) {
280+
const assistantResponse = responseContainer.querySelector('.assistant-response');
271281
// Send query to server
272282
rateLimitingFetch('/ask', {
283+
method: 'POST',
284+
headers: {
285+
'Content-Type': 'application/json',
286+
'X-CSRFToken': '{{ csrf_token() }}'
287+
},
288+
body: JSON.stringify({ query: query }),
289+
})
290+
.then(response => {
291+
const reader = response.body.getReader();
292+
const decoder = new TextDecoder();
293+
let responseText = '';
294+
295+
function readStream() {
296+
reader.read().then(({ done, value }) => {
297+
if (done) {
298+
loadingSpinner.style.display = 'none';
299+
sendButton.disabled = false;
300+
return;
301+
}
302+
303+
const chunk = decoder.decode(value);
304+
responseText += chunk;
305+
assistantResponse.innerHTML = marked.parse(responseText);
306+
chatBox.scrollTop = chatBox.scrollHeight;
307+
308+
readStream();
309+
});
310+
}
311+
312+
readStream();
313+
})
314+
.catch(error => {
315+
console.error('Error:', error);
316+
loadingSpinner.style.display = 'none';
317+
sendButton.disabled = false;
318+
assistantResponse.textContent = 'Error: Failed to get response';
319+
chatBox.scrollTop = chatBox.scrollHeight;
320+
});
321+
}
322+
323+
function askDebug(query) {
324+
// fetch context for debugging
325+
rateLimitingFetch('/ask/debug', {
273326
method: 'POST',
274327
headers: {
275328
'Content-Type': 'application/json',
276-
'X-CSRFToken': '{{ csrf_token() }}'
329+
'X-CSRFToken': '{{ csrf_token() }}'
277330
},
278-
body: JSON.stringify({ query: query }),
331+
body: JSON.stringify({ query: query}),
279332
})
280-
.then(response => {
281-
const reader = response.body.getReader();
282-
const decoder = new TextDecoder();
283-
let responseText = '';
284-
285-
function readStream() {
286-
reader.read().then(({ done, value }) => {
287-
if (done) {
288-
loadingSpinner.style.display = 'none';
289-
sendButton.disabled = false;
290-
return;
291-
}
292-
293-
const chunk = decoder.decode(value);
294-
responseText += chunk;
295-
assistantResponse.innerHTML = marked.parse(responseText);
296-
chatBox.scrollTop = chatBox.scrollHeight;
297-
298-
readStream();
299-
});
333+
.then(response => response.json())
334+
.then(data => {
335+
if (data.error) {
336+
chatBox.innerHTML += `<p><strong>Debug Context:</strong> Error: ${data.error}</p>`;
337+
} else {
338+
chatBox.innerHTML += `<p><strong>Debug Context:</strong> ${data.context}</p>`;
300339
}
301-
302-
readStream();
340+
chatBox.scrollTop = chatBox.scrollHeight;
341+
loadingSpinner.style.display = 'none';
342+
sendButton.disabled = false;
303343
})
304344
.catch(error => {
305345
console.error('Error:', error);
306346
loadingSpinner.style.display = 'none';
307347
sendButton.disabled = false;
308-
assistantResponse.textContent = 'Error: Failed to get response';
348+
chatBox.innerHTML += '<p><strong>Debug Context:</strong> Error: Failed to get context</p>';
309349
chatBox.scrollTop = chatBox.scrollHeight;
310350
});
311351
}

0 commit comments

Comments
 (0)