Skip to content

Commit 601e548

Browse files
Merge branch 'main' into jordan/devex2
2 parents 00fa1e3 + 0b5bb74 commit 601e548

File tree

6 files changed

+842
-49
lines changed

6 files changed

+842
-49
lines changed

app/app.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def validate_pow(nonce, data, difficulty):
2929

3030
@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the chat page, exposing the DEBUG env var to the template.

    The template uses the value (via a <meta> tag) to decide whether the
    client should call the debug context endpoint instead of /ask.
    """
    debug_flag = os.getenv('DEBUG')
    return render_template('index.html', debug=debug_flag)
3333

3434
@app.route('/ask', methods=['POST'])
3535
def ask():
@@ -99,5 +99,15 @@ def trigger_rebuild():
9999
print(f"Error in /trigger-rebuild endpoint: {e}")
100100
return jsonify({"error": "Internal Server Error"}), 500
101101

102+
if os.getenv('DEBUG') == '1':
    @app.route('/ask/debug', methods=['POST'])
    def debug_context():
        """Debug-only endpoint: return the raw RAG context for a query.

        Registered only when the DEBUG env var is '1', so it is not
        reachable in production. Expects JSON {"query": "..."} and
        responds with {"context": "..."}, or 400 when the query is missing.
        """
        # get_json() returns None for an empty body and aborts on malformed
        # JSON; silent=True + `or {}` guarantees `data` is a dict, so a bad
        # request yields our own 400 JSON error instead of a 500.
        data = request.get_json(silent=True) or {}
        query = data.get('query', '')
        if not query:
            return jsonify({"error": "Query is required"}), 400
        context = rag_system.get_context(query)
        return jsonify({"context": context})
111+
102112
# Run the Flask development server when executed directly
# (a production deployment would use a WSGI server instead).
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5050)

app/rag_system.py

Lines changed: 84 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,38 +27,92 @@ def embed_knowledge_base(self):
2727

2828
def normalize_query(self, query):
2929
return query.lower().strip()
30-
31-
def retrieve(self, query, similarity_threshold=0.7, high_match_threshold=0.8, max_docs=5):
30+
31+
def get_query_embedding(self, query, use_cpu=False):
    """Encode a query into an embedding tensor after normalizing it.

    With use_cpu=True the tensor is moved to the CPU, which is useful for
    tests or environments without a GPU.
    """
    embedding = self.model.encode([self.normalize_query(query)], convert_to_tensor=True)
    return embedding.cpu() if use_cpu else embedding
37+
38+
def get_doc_embeddings(self, use_cpu=False):
    """Return the precomputed document embeddings, optionally moved to CPU."""
    embeddings = self.doc_embeddings
    return embeddings.cpu() if use_cpu else embeddings
42+
43+
def compute_document_scores(self, query_embedding, doc_embeddings, high_match_threshold):
    """Score every knowledge-base document against the query.

    Combines the query's similarity to each document's text (via the
    precomputed doc_embeddings) with its similarity to the document's
    "about" field, and returns one dict per document carrying both raw
    similarities and the combined relevance score.
    """
    if not self.knowledge_base:
        # No documents: nothing to score (avoids encoding an empty batch).
        return []

    text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]

    # Encode all "about" fields in a single batched call instead of one
    # model.encode() per document: same embeddings, far fewer model calls.
    about_texts = [doc["about"] for doc in self.knowledge_base]
    about_embeddings = self.model.encode(about_texts)
    about_similarities = list(cosine_similarity(query_embedding, about_embeddings)[0])

    relevance_scores = self.compute_relevance_scores(text_similarities, about_similarities, high_match_threshold)

    return [
        {
            "index": i,
            "about": doc["about"],
            "text": doc["text"],
            "text_similarity": text_similarities[i],
            "about_similarity": about_similarities[i],
            "relevance_score": relevance_scores[i],
        }
        for i, doc in enumerate(self.knowledge_base)
    ]
4965

50-
retrieved_docs = [f'{self.knowledge_base[i]["about"]}. {self.knowledge_base[i]["text"]}' for i in top_indices]
66+
def retrieve(self, query, similarity_threshold=0.7, high_match_threshold=0.8, max_docs=5, use_cpu=False):
    """Retrieve the most relevant knowledge-base documents for a query.

    Set use_cpu=True to run the embedding comparison on the CPU (useful
    for testing or GPU-less environments); leave it False in production
    to leverage the GPU for better performance.
    """
    query_embedding = self.get_query_embedding(query, use_cpu)
    doc_embeddings = self.get_doc_embeddings(use_cpu)

    scored_docs = self.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold)
    top_docs = self.get_top_docs(scored_docs, similarity_threshold, max_docs)

    # Never return an empty list: fall back to a canned "nothing found" doc.
    return top_docs if top_docs else self.get_fallback_doc()
79+
5780

81+
def compute_relevance_scores(self, text_similarities, about_similarities, high_match_threshold):
    """Combine per-document text and "about" similarities into one score.

    A document whose text OR "about" similarity reaches high_match_threshold
    is scored by its best individual similarity; otherwise a 30/70 weighted
    blend of about/text similarity is used.
    """
    scores = []
    for i in range(len(self.knowledge_base)):
        about_sim = about_similarities[i]
        text_sim = text_similarities[i]
        best = max(about_sim, text_sim)
        if best >= high_match_threshold:
            # A strong match on either signal dominates the blended score.
            scores.append(best)
        else:
            scores.append(0.3 * about_sim + 0.7 * text_sim)
    return scores
94+
95+
def get_top_docs(self, doc_scores, similarity_threshold, max_docs):
    """Return up to max_docs documents meeting the threshold, best first."""
    ranked = sorted(doc_scores, key=lambda doc: doc["relevance_score"], reverse=True)
    candidates = ranked[:max_docs]
    return [doc for doc in candidates if doc["relevance_score"] >= similarity_threshold]
100+
101+
def get_fallback_doc(self):
    """Return a single placeholder document used when retrieval finds nothing."""
    fallback = {
        "about": "No Relevant Information Found",
        "text": (
            "I'm sorry, I couldn't find any relevant information for your query. "
            "Please try rephrasing your question or ask about a different topic. "
            "For further assistance, you can visit our official website or reach out to our support team."
        ),
    }
    return [fallback]
112+
58113
def answer_query_stream(self, query):
59114
try:
60-
normalized_query = self.normalize_query(query)
61-
context = self.retrieve(normalized_query)
115+
context = self.get_context(query)
62116

63117
self.conversation_history.append({"role": "user", "content": query})
64118

@@ -118,5 +172,13 @@ def rebuild_embeddings(self):
118172
self.doc_embeddings = self.embed_knowledge_base() # Rebuild the embeddings
119173
print("Embeddings have been rebuilt.")
120174

175+
def get_context(self, query):
    """Normalize the query, retrieve matching docs, and join them as one string."""
    docs = self.retrieve(self.normalize_query(query))
    # Each document renders as '<about>. <text>'; documents are blank-line separated.
    return "\n\n".join(f'{doc["about"]}. {doc["text"]}' for doc in docs)
182+
121183
# Module-level singleton shared by the Flask routes in app.py.
rag_system = RAGSystem()

app/templates/index.html

Lines changed: 66 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
<head>
44
<meta charset="UTF-8">
55
<meta name="viewport" content="width=device-width, initial-scale=1.0">
6+
<meta name="debug" content="{{ debug }}">
67
<title>Ask Defang</title>
78
<script src="https://cdnjs.cloudflare.com/ajax/libs/marked/4.0.2/marked.min.js"></script>
89
<link rel="icon" href="{{ url_for('static', filename='images/favicon.ico') }}" type="image/x-icon">
@@ -259,7 +260,6 @@ <h2>Ask Defang</h2>
259260
const responseContainer = document.createElement('p');
260261
responseContainer.innerHTML = `<strong>Assistant:</strong> <span class="assistant-response"></span>`;
261262
chatBox.appendChild(responseContainer);
262-
const assistantResponse = responseContainer.querySelector('.assistant-response');
263263

264264
chatBox.scrollTop = chatBox.scrollHeight;
265265

@@ -268,44 +268,84 @@ <h2>Ask Defang</h2>
268268
loadingSpinner.style.display = 'inline-block';
269269
sendButton.disabled = true;
270270

271+
const debug = document.querySelector('meta[name=debug]').content;
272+
if (debug == '1') {
273+
askDebug(query)
274+
} else {
275+
ask(query, responseContainer)
276+
}
277+
}
278+
279+
function ask(query, responseContainer) {
    const assistantResponse = responseContainer.querySelector('.assistant-response');
    // POST the query and stream the markdown answer into the chat box.
    rateLimitingFetch('/ask', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
            'X-CSRFToken': '{{ csrf_token() }}'
        },
        body: JSON.stringify({ query: query }),
    })
    .then(response => {
        const streamReader = response.body.getReader();
        const textDecoder = new TextDecoder();
        let accumulated = '';

        const pump = () => streamReader.read().then(({ done, value }) => {
            if (done) {
                // Stream finished: re-enable the input controls.
                loadingSpinner.style.display = 'none';
                sendButton.disabled = false;
                return;
            }
            accumulated += textDecoder.decode(value);
            // Re-render the full accumulated markdown on every chunk.
            assistantResponse.innerHTML = marked.parse(accumulated);
            chatBox.scrollTop = chatBox.scrollHeight;
            pump();
        });

        pump();
    })
    .catch(error => {
        console.error('Error:', error);
        loadingSpinner.style.display = 'none';
        sendButton.disabled = false;
        assistantResponse.textContent = 'Error: Failed to get response';
        chatBox.scrollTop = chatBox.scrollHeight;
    });
}
322+
323+
function askDebug(query) {
    // Debug mode: fetch the raw retrieval context instead of a streamed answer.
    rateLimitingFetch('/ask/debug', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
            'X-CSRFToken': '{{ csrf_token() }}'
        },
        body: JSON.stringify({ query: query }),
    })
    .then(response => response.json())
    .then(data => {
        // NOTE(review): data.context / data.error are interpolated into
        // innerHTML unescaped — presumably acceptable for this debug-only
        // view, but confirm the context can never contain hostile HTML.
        const message = data.error
            ? `<p><strong>Debug Context:</strong> Error: ${data.error}</p>`
            : `<p><strong>Debug Context:</strong> ${data.context}</p>`;
        chatBox.innerHTML += message;
        chatBox.scrollTop = chatBox.scrollHeight;
        loadingSpinner.style.display = 'none';
        sendButton.disabled = false;
    })
    .catch(error => {
        console.error('Error:', error);
        loadingSpinner.style.display = 'none';
        sendButton.disabled = false;
        chatBox.innerHTML += '<p><strong>Debug Context:</strong> Error: Failed to get context</p>';
        chatBox.scrollTop = chatBox.scrollHeight;
    });
}

0 commit comments

Comments
 (0)