@@ -27,14 +27,16 @@ def embed_knowledge_base(self):
2727 def normalize_query (self , query ):
2828 return query .lower ().strip ()
2929
30- def get_query_embedding (self , query , use_cpu = False ):
30+ def get_query_embedding (self , query , use_cpu = True ):
3131 normalized_query = self .normalize_query (query )
3232 query_embedding = self .model .encode ([normalized_query ], convert_to_tensor = True )
33+ # If use_cpu is set, move the query embedding to the CPU to ensure compatibility with operations like cosine_similarity
3334 if use_cpu :
3435 query_embedding = query_embedding .cpu ()
3536 return query_embedding
3637
37- def get_doc_embeddings (self , use_cpu = False ):
38+ def get_doc_embeddings (self , use_cpu = True ):
39+ # If use_cpu is set, return the document embeddings on the CPU to ensure compatibility with operations like cosine_similarity
3840 if use_cpu :
3941 return self .doc_embeddings .cpu ()
4042 return self .doc_embeddings
@@ -62,12 +64,9 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_th
6264
6365 return result
6466
65- def retrieve (self , query , similarity_threshold = 0.7 , high_match_threshold = 0.8 , max_docs = 5 , use_cpu = False ):
66- # Note: Set use_cpu=True to run on CPU, which is useful for testing or environments without a GPU.
67- # Set use_cpu=False to leverage GPU for better performance in production.
68-
69- query_embedding = self .get_query_embedding (query , use_cpu )
70- doc_embeddings = self .get_doc_embeddings (use_cpu )
67+ def retrieve (self , query , similarity_threshold = 0.7 , high_match_threshold = 0.8 , max_docs = 5 ):
68+ query_embedding = self .get_query_embedding (query )
69+ doc_embeddings = self .get_doc_embeddings ()
7170
7271 doc_scores = self .compute_document_scores (query_embedding , doc_embeddings , high_match_threshold )
7372 retrieved_docs = self .get_top_docs (doc_scores , similarity_threshold , max_docs )
0 commit comments