@@ -27,16 +27,14 @@ def embed_knowledge_base(self):
2727 def normalize_query (self , query ):
2828 return query .lower ().strip ()
2929
30- def get_query_embedding (self , query , use_cpu = True ):
30+ def get_query_embedding (self , query , use_cpu = False ):
3131 normalized_query = self .normalize_query (query )
3232 query_embedding = self .model .encode ([normalized_query ], convert_to_tensor = True )
33- # Move the embeddings to the CPU to ensure compatibility with operations like cosine_similarity
3433 if use_cpu :
3534 query_embedding = query_embedding .cpu ()
3635 return query_embedding
3736
38- def get_doc_embeddings (self , use_cpu = True ):
39- # Move the embeddings to the CPU to ensure compatibility with operations like cosine_similarity
37+ def get_doc_embeddings (self , use_cpu = False ):
4038 if use_cpu :
4139 return self .doc_embeddings .cpu ()
4240 return self .doc_embeddings
@@ -64,9 +62,12 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_th
6462
6563 return result
6664
67- def retrieve (self , query , similarity_threshold = 0.7 , high_match_threshold = 0.8 , max_docs = 5 ):
68- query_embedding = self .get_query_embedding (query )
69- doc_embeddings = self .get_doc_embeddings ()
65+ def retrieve (self , query , similarity_threshold = 0.7 , high_match_threshold = 0.8 , max_docs = 5 , use_cpu = False ):
66+ # Note: Set use_cpu=True to run on CPU, which is useful for testing or environments without a GPU.
67+ # Set use_cpu=False to leverage GPU for better performance in production.
68+
69+ query_embedding = self .get_query_embedding (query , use_cpu )
70+ doc_embeddings = self .get_doc_embeddings (use_cpu )
7071
7172 doc_scores = self .compute_document_scores (query_embedding , doc_embeddings , high_match_threshold )
7273 retrieved_docs = self .get_top_docs (doc_scores , similarity_threshold , max_docs )
0 commit comments