diff --git a/app/rag_system.py b/app/rag_system.py
index 32f2900..a272624 100644
--- a/app/rag_system.py
+++ b/app/rag_system.py
@@ -21,19 +21,26 @@ def __init__(self, knowledge_base_path='./data/knowledge_base.json'):
 
         # load existing embeddings if available
         logging.info("Embedding knowledge base...")
-        if os.path.exists('./data/doc_embeddings.npy'):
+
+        if os.path.exists('./data/doc_about_embeddings.npy') and os.path.exists('./data/doc_embeddings.npy'):
+            self.doc_about_embeddings = np.load('./data/doc_about_embeddings.npy')
+            logging.info("Loaded existing document about embeddings from disk.")
             self.doc_embeddings = np.load('./data/doc_embeddings.npy')
             logging.info("Loaded existing document embeddings from disk.")
         else:
             self.rebuild_embeddings()
+            logging.info("Knowledge base embeddings created")
 
         self.conversation_history = []
 
     def rebuild_embeddings(self):
         logging.info("No existing document embeddings found, creating new embeddings.")
         self.doc_embeddings = self.embed_knowledge_base()
+        self.doc_about_embeddings = self.embed_knowledge_base_about()
         # cache doc_embeddings to disk
         np.save('./data/doc_embeddings.npy', self.doc_embeddings.cpu().numpy())
+        np.save('./data/doc_about_embeddings.npy', self.doc_about_embeddings.cpu().numpy())
+
 
     def load_knowledge_base(self):
         with open(self.knowledge_base_path, 'r') as kb_file:
@@ -43,6 +50,9 @@ def embed_knowledge_base(self):
         docs = [f'{doc["about"]}. {doc["text"]}' for doc in self.knowledge_base]
         return self.model.encode(docs, convert_to_tensor=True)
 
+    def embed_knowledge_base_about(self):
+        return self.model.encode([doc["about"] for doc in self.knowledge_base], convert_to_tensor=True)
+
     def normalize_query(self, query):
         return query.lower().strip()
 
@@ -55,13 +65,12 @@ def get_query_embedding(self, query):
     def get_doc_embeddings(self):
         return self.doc_embeddings
 
-    def compute_document_scores(self, query_embedding, doc_embeddings, high_match_threshold):
-        text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
-        about_similarities = []
-        for doc in self.knowledge_base:
-            about_similarity = cosine_similarity(query_embedding, self.model.encode([doc["about"]]))[0][0]
-            about_similarities.append(about_similarity)
+    def get_doc_about_embeddings(self):
+        return self.doc_about_embeddings
 
+    def compute_document_scores(self, query_embedding, doc_embeddings, doc_about_embeddings, high_match_threshold):
+        text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
+        about_similarities = cosine_similarity(query_embedding, doc_about_embeddings)[0]
         relevance_scores = self.compute_relevance_scores(text_similarities, about_similarities, high_match_threshold)
 
         result = [
@@ -82,8 +91,9 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_th
     def retrieve(self, query, similarity_threshold=0.4, high_match_threshold=0.8, max_docs=5):
         query_embedding = self.get_query_embedding(query)
         doc_embeddings = self.get_doc_embeddings()
+        doc_about_embeddings = self.get_doc_about_embeddings()
 
-        doc_scores = self.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold)
+        doc_scores = self.compute_document_scores(query_embedding, doc_embeddings, doc_about_embeddings, high_match_threshold)
         retrieved_docs = self.get_top_docs(doc_scores, similarity_threshold, max_docs)
 
         if not retrieved_docs:
diff --git a/app/test_rag_system.py b/app/test_rag_system.py
index 153f60a..15f15bf 100644
--- a/app/test_rag_system.py
+++ b/app/test_rag_system.py
@@ -80,9 +80,10 @@ def test_compute_document_scores(self):
         query = "Does Defang have an MCP sample?"
         query_embedding = self.rag_system.get_query_embedding(query)
         doc_embeddings = self.rag_system.get_doc_embeddings()
+        doc_about_embeddings = self.rag_system.get_doc_about_embeddings()
 
         # call function and get results
-        result = self.rag_system.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold=0.8)
+        result = self.rag_system.compute_document_scores(query_embedding, doc_embeddings, doc_about_embeddings, high_match_threshold=0.8)
 
         # sort the result by relevance score in descending order
         result = sorted(result, key=lambda x: x["relevance_score"], reverse=True)