@@ -21,19 +21,26 @@ def __init__(self, knowledge_base_path='./data/knowledge_base.json'):
21
21
22
22
# load existing embeddings if available
23
23
logging .info ("Embedding knowledge base..." )
24
- if os .path .exists ('./data/doc_embeddings.npy' ):
24
+
25
+ if os .path .exists ('./data/doc_about_embeddings.npy' ) and os .path .exists ('./data/doc_embeddings.npy' ):
26
+ self .doc_about_embeddings = np .load ('./data/doc_about_embeddings.npy' )
27
+ logging .info ("Loaded existing about document about embeddings from disk." )
25
28
self .doc_embeddings = np .load ('./data/doc_embeddings.npy' )
26
29
logging .info ("Loaded existing document embeddings from disk." )
27
30
else :
28
31
self .rebuild_embeddings ()
32
+
29
33
logging .info ("Knowledge base embeddings created" )
30
34
self .conversation_history = []
31
35
32
36
def rebuild_embeddings(self, *, embeddings_path='./data/doc_embeddings.npy',
                       about_embeddings_path='./data/doc_about_embeddings.npy'):
    """Recompute both embedding matrices from the knowledge base and cache them.

    Sets ``self.doc_embeddings`` (full-document embeddings) and
    ``self.doc_about_embeddings`` ("about"-field embeddings), then saves each
    as a ``.npy`` file so later runs can load from disk instead of re-encoding.

    Args:
        embeddings_path: destination ``.npy`` file for the document embeddings.
            Defaults match the path the constructor loads from.
        about_embeddings_path: destination ``.npy`` file for the "about"
            embeddings. Defaults match the path the constructor loads from.
    """
    logging.info("No existing document embeddings found, creating new embeddings.")
    self.doc_embeddings = self.embed_knowledge_base()
    self.doc_about_embeddings = self.embed_knowledge_base_about()
    # Cache embeddings to disk; tensors may live on a non-CPU device, so move
    # to CPU before converting for numpy serialization.
    np.save(embeddings_path, self.doc_embeddings.cpu().numpy())
    np.save(about_embeddings_path, self.doc_about_embeddings.cpu().numpy())
37
44
38
45
def load_knowledge_base (self ):
39
46
with open (self .knowledge_base_path , 'r' ) as kb_file :
@@ -43,6 +50,9 @@ def embed_knowledge_base(self):
43
50
docs = [f'{ doc ["about" ]} . { doc ["text" ]} ' for doc in self .knowledge_base ]
44
51
return self .model .encode (docs , convert_to_tensor = True )
45
52
53
def embed_knowledge_base_about(self):
    """Encode only the "about" field of each knowledge-base entry.

    Returns a tensor with one embedding row per document, used for
    about-level similarity scoring alongside the full-text embeddings.
    """
    about_texts = []
    for entry in self.knowledge_base:
        about_texts.append(entry["about"])
    return self.model.encode(about_texts, convert_to_tensor=True)
46
56
def normalize_query(self, query):
    """Canonicalize a user query: lowercase it, then trim surrounding whitespace."""
    lowered = query.lower()
    return lowered.strip()
48
58
@@ -55,13 +65,12 @@ def get_query_embedding(self, query):
55
65
def get_doc_embeddings(self):
    """Return the cached full-document embedding tensor."""
    return self.doc_embeddings
57
67
58
- def compute_document_scores (self , query_embedding , doc_embeddings , high_match_threshold ):
59
- text_similarities = cosine_similarity (query_embedding , doc_embeddings )[0 ]
60
- about_similarities = []
61
- for doc in self .knowledge_base :
62
- about_similarity = cosine_similarity (query_embedding , self .model .encode ([doc ["about" ]]))[0 ][0 ]
63
- about_similarities .append (about_similarity )
68
def get_doc_about_embeddings(self):
    """Return the cached "about"-field embedding tensor."""
    return self.doc_about_embeddings
64
70
71
+ def compute_document_scores (self , query_embedding , doc_embeddings , doc_about_embeddings , high_match_threshold ):
72
+ text_similarities = cosine_similarity (query_embedding , doc_embeddings )[0 ]
73
+ about_similarities = cosine_similarity (query_embedding , doc_about_embeddings )[0 ]
65
74
relevance_scores = self .compute_relevance_scores (text_similarities , about_similarities , high_match_threshold )
66
75
67
76
result = [
@@ -82,8 +91,9 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_th
82
91
def retrieve (self , query , similarity_threshold = 0.4 , high_match_threshold = 0.8 , max_docs = 5 ):
83
92
query_embedding = self .get_query_embedding (query )
84
93
doc_embeddings = self .get_doc_embeddings ()
94
+ doc_about_embeddings = self .get_doc_about_embeddings ()
85
95
86
- doc_scores = self .compute_document_scores (query_embedding , doc_embeddings , high_match_threshold )
96
+ doc_scores = self .compute_document_scores (query_embedding , doc_embeddings , doc_about_embeddings , high_match_threshold )
87
97
retrieved_docs = self .get_top_docs (doc_scores , similarity_threshold , max_docs )
88
98
89
99
if not retrieved_docs :
0 commit comments