Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions app/rag_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,26 @@ def __init__(self, knowledge_base_path='./data/knowledge_base.json'):

# load existing embeddings if available
logging.info("Embedding knowledge base...")
if os.path.exists('./data/doc_embeddings.npy'):

if os.path.exists('./data/doc_about_embeddings.npy') and os.path.exists('./data/doc_embeddings.npy'):
self.doc_about_embeddings = np.load('./data/doc_about_embeddings.npy')
logging.info("Loaded existing about document about embeddings from disk.")
self.doc_embeddings = np.load('./data/doc_embeddings.npy')
logging.info("Loaded existing document embeddings from disk.")
else:
self.rebuild_embeddings()

logging.info("Knowledge base embeddings created")
self.conversation_history = []

def rebuild_embeddings(self):
logging.info("No existing document embeddings found, creating new embeddings.")
self.doc_embeddings = self.embed_knowledge_base()
self.doc_about_embeddings = self.embed_knowledge_base_about()
# cache doc_embeddings to disk
np.save('./data/doc_embeddings.npy', self.doc_embeddings.cpu().numpy())
np.save('./data/doc_about_embeddings.npy', self.doc_about_embeddings.cpu().numpy())


def load_knowledge_base(self):
with open(self.knowledge_base_path, 'r') as kb_file:
Expand All @@ -43,6 +50,9 @@ def embed_knowledge_base(self):
docs = [f'{doc["about"]}. {doc["text"]}' for doc in self.knowledge_base]
return self.model.encode(docs, convert_to_tensor=True)

def embed_knowledge_base_about(self):
return self.model.encode([doc["about"] for doc in self.knowledge_base], convert_to_tensor=True)

def normalize_query(self, query):
return query.lower().strip()

Expand All @@ -55,13 +65,12 @@ def get_query_embedding(self, query):
def get_doc_embeddings(self):
return self.doc_embeddings

def compute_document_scores(self, query_embedding, doc_embeddings, high_match_threshold):
text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
about_similarities = []
for doc in self.knowledge_base:
about_similarity = cosine_similarity(query_embedding, self.model.encode([doc["about"]]))[0][0]
about_similarities.append(about_similarity)
def get_doc_about_embeddings(self):
return self.doc_about_embeddings

def compute_document_scores(self, query_embedding, doc_embeddings, doc_about_embeddings, high_match_threshold):
text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
about_similarities = cosine_similarity(query_embedding, doc_about_embeddings)[0]
relevance_scores = self.compute_relevance_scores(text_similarities, about_similarities, high_match_threshold)

result = [
Expand All @@ -82,8 +91,9 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_th
def retrieve(self, query, similarity_threshold=0.4, high_match_threshold=0.8, max_docs=5):
query_embedding = self.get_query_embedding(query)
doc_embeddings = self.get_doc_embeddings()
doc_about_embeddings = self.get_doc_about_embeddings()

doc_scores = self.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold)
doc_scores = self.compute_document_scores(query_embedding, doc_embeddings, doc_about_embeddings, high_match_threshold)
retrieved_docs = self.get_top_docs(doc_scores, similarity_threshold, max_docs)

if not retrieved_docs:
Expand Down
3 changes: 2 additions & 1 deletion app/test_rag_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ def test_compute_document_scores(self):
query = "Does Defang have an MCP sample?"
query_embedding = self.rag_system.get_query_embedding(query)
doc_embeddings = self.rag_system.get_doc_embeddings()
doc_about_embeddings = self.rag_system.doc_about_embeddings()

# call function and get results
result = self.rag_system.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold=0.8)
result = self.rag_system.compute_document_scores(query_embedding, doc_embeddings, doc_about_embeddings, high_match_threshold=0.8)
# sort the result by relevance score in descending order
result = sorted(result, key=lambda x: x["relevance_score"], reverse=True)

Expand Down