Skip to content

Commit b87d853

Browse files
authored
Merge pull request #86 from DefangLabs/eric/cache-doc-about-embedding
add about embedding cache
2 parents 7eb2f3f + 469ce8d commit b87d853

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

app/rag_system.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,26 @@ def __init__(self, knowledge_base_path='./data/knowledge_base.json'):
2121

2222
# load existing embeddings if available
2323
logging.info("Embedding knowledge base...")
24-
if os.path.exists('./data/doc_embeddings.npy'):
24+
25+
if os.path.exists('./data/doc_about_embeddings.npy') and os.path.exists('./data/doc_embeddings.npy'):
26+
self.doc_about_embeddings = np.load('./data/doc_about_embeddings.npy')
27+
logging.info("Loaded existing about document about embeddings from disk.")
2528
self.doc_embeddings = np.load('./data/doc_embeddings.npy')
2629
logging.info("Loaded existing document embeddings from disk.")
2730
else:
2831
self.rebuild_embeddings()
32+
2933
logging.info("Knowledge base embeddings created")
3034
self.conversation_history = []
3135

3236
def rebuild_embeddings(self):
3337
logging.info("No existing document embeddings found, creating new embeddings.")
3438
self.doc_embeddings = self.embed_knowledge_base()
39+
self.doc_about_embeddings = self.embed_knowledge_base_about()
3540
# cache doc_embeddings to disk
3641
np.save('./data/doc_embeddings.npy', self.doc_embeddings.cpu().numpy())
42+
np.save('./data/doc_about_embeddings.npy', self.doc_about_embeddings.cpu().numpy())
43+
3744

3845
def load_knowledge_base(self):
3946
with open(self.knowledge_base_path, 'r') as kb_file:
@@ -43,6 +50,9 @@ def embed_knowledge_base(self):
4350
docs = [f'{doc["about"]}. {doc["text"]}' for doc in self.knowledge_base]
4451
return self.model.encode(docs, convert_to_tensor=True)
4552

53+
def embed_knowledge_base_about(self):
54+
return self.model.encode([doc["about"] for doc in self.knowledge_base], convert_to_tensor=True)
55+
4656
def normalize_query(self, query):
4757
return query.lower().strip()
4858

@@ -55,13 +65,12 @@ def get_query_embedding(self, query):
5565
def get_doc_embeddings(self):
5666
return self.doc_embeddings
5767

58-
def compute_document_scores(self, query_embedding, doc_embeddings, high_match_threshold):
59-
text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
60-
about_similarities = []
61-
for doc in self.knowledge_base:
62-
about_similarity = cosine_similarity(query_embedding, self.model.encode([doc["about"]]))[0][0]
63-
about_similarities.append(about_similarity)
68+
def get_doc_about_embeddings(self):
69+
return self.doc_about_embeddings
6470

71+
def compute_document_scores(self, query_embedding, doc_embeddings, doc_about_embeddings, high_match_threshold):
72+
text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
73+
about_similarities = cosine_similarity(query_embedding, doc_about_embeddings)[0]
6574
relevance_scores = self.compute_relevance_scores(text_similarities, about_similarities, high_match_threshold)
6675

6776
result = [
@@ -82,8 +91,9 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_th
8291
def retrieve(self, query, similarity_threshold=0.4, high_match_threshold=0.8, max_docs=5):
8392
query_embedding = self.get_query_embedding(query)
8493
doc_embeddings = self.get_doc_embeddings()
94+
doc_about_embeddings = self.get_doc_about_embeddings()
8595

86-
doc_scores = self.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold)
96+
doc_scores = self.compute_document_scores(query_embedding, doc_embeddings, doc_about_embeddings, high_match_threshold)
8797
retrieved_docs = self.get_top_docs(doc_scores, similarity_threshold, max_docs)
8898

8999
if not retrieved_docs:

app/test_rag_system.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,10 @@ def test_compute_document_scores(self):
8080
query = "Does Defang have an MCP sample?"
8181
query_embedding = self.rag_system.get_query_embedding(query)
8282
doc_embeddings = self.rag_system.get_doc_embeddings()
83+
doc_about_embeddings = self.rag_system.doc_about_embeddings()
8384

8485
# call function and get results
85-
result = self.rag_system.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold=0.8)
86+
result = self.rag_system.compute_document_scores(query_embedding, doc_embeddings, doc_about_embeddings, high_match_threshold=0.8)
8687
# sort the result by relevance score in descending order
8788
result = sorted(result, key=lambda x: x["relevance_score"], reverse=True)
8889

0 commit comments

Comments
 (0)