@@ -21,19 +21,26 @@ def __init__(self, knowledge_base_path='./data/knowledge_base.json'):
2121
2222 # load existing embeddings if available
2323 logging .info ("Embedding knowledge base..." )
24- if os .path .exists ('./data/doc_embeddings.npy' ):
24+
25+ if os .path .exists ('./data/doc_about_embeddings.npy' ) and os .path .exists ('./data/doc_embeddings.npy' ):
26+ self .doc_about_embeddings = np .load ('./data/doc_about_embeddings.npy' )
27+ logging .info ("Loaded existing about document about embeddings from disk." )
2528 self .doc_embeddings = np .load ('./data/doc_embeddings.npy' )
2629 logging .info ("Loaded existing document embeddings from disk." )
2730 else :
2831 self .rebuild_embeddings ()
32+
2933 logging .info ("Knowledge base embeddings created" )
3034 self .conversation_history = []
3135
3236 def rebuild_embeddings (self ):
3337 logging .info ("No existing document embeddings found, creating new embeddings." )
3438 self .doc_embeddings = self .embed_knowledge_base ()
39+ self .doc_about_embeddings = self .embed_knowledge_base_about ()
3540 # cache doc_embeddings to disk
3641 np .save ('./data/doc_embeddings.npy' , self .doc_embeddings .cpu ().numpy ())
42+ np .save ('./data/doc_about_embeddings.npy' , self .doc_about_embeddings .cpu ().numpy ())
43+
3744
3845 def load_knowledge_base (self ):
3946 with open (self .knowledge_base_path , 'r' ) as kb_file :
@@ -43,6 +50,9 @@ def embed_knowledge_base(self):
4350 docs = [f'{ doc ["about" ]} . { doc ["text" ]} ' for doc in self .knowledge_base ]
4451 return self .model .encode (docs , convert_to_tensor = True )
4552
53+ def embed_knowledge_base_about (self ):
54+ return self .model .encode ([doc ["about" ] for doc in self .knowledge_base ], convert_to_tensor = True )
55+
4656 def normalize_query (self , query ):
4757 return query .lower ().strip ()
4858
@@ -55,13 +65,12 @@ def get_query_embedding(self, query):
5565 def get_doc_embeddings (self ):
5666 return self .doc_embeddings
5767
58- def compute_document_scores (self , query_embedding , doc_embeddings , high_match_threshold ):
59- text_similarities = cosine_similarity (query_embedding , doc_embeddings )[0 ]
60- about_similarities = []
61- for doc in self .knowledge_base :
62- about_similarity = cosine_similarity (query_embedding , self .model .encode ([doc ["about" ]]))[0 ][0 ]
63- about_similarities .append (about_similarity )
68+ def get_doc_about_embeddings (self ):
69+ return self .doc_about_embeddings
6470
71+ def compute_document_scores (self , query_embedding , doc_embeddings , doc_about_embeddings , high_match_threshold ):
72+ text_similarities = cosine_similarity (query_embedding , doc_embeddings )[0 ]
73+ about_similarities = cosine_similarity (query_embedding , doc_about_embeddings )[0 ]
6574 relevance_scores = self .compute_relevance_scores (text_similarities , about_similarities , high_match_threshold )
6675
6776 result = [
@@ -82,8 +91,9 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_th
8291 def retrieve (self , query , similarity_threshold = 0.4 , high_match_threshold = 0.8 , max_docs = 5 ):
8392 query_embedding = self .get_query_embedding (query )
8493 doc_embeddings = self .get_doc_embeddings ()
94+ doc_about_embeddings = self .get_doc_about_embeddings ()
8595
86- doc_scores = self .compute_document_scores (query_embedding , doc_embeddings , high_match_threshold )
96+ doc_scores = self .compute_document_scores (query_embedding , doc_embeddings , doc_about_embeddings , high_match_threshold )
8797 retrieved_docs = self .get_top_docs (doc_scores , similarity_threshold , max_docs )
8898
8999 if not retrieved_docs :
0 commit comments