Skip to content

Commit 020e1b3

Browse files
committed
Make CPU tensors the default for both testing and production
1 parent 8d4ca04 commit 020e1b3

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

app/rag_system.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,16 @@ def embed_knowledge_base(self):
2727
def normalize_query(self, query):
2828
return query.lower().strip()
2929

30-
def get_query_embedding(self, query, use_cpu=False):
30+
def get_query_embedding(self, query, use_cpu=True):
3131
normalized_query = self.normalize_query(query)
3232
query_embedding = self.model.encode([normalized_query], convert_to_tensor=True)
33+
# Move the embeddings to the CPU to ensure compatibility with operations like cosine_similarity
3334
if use_cpu:
3435
query_embedding = query_embedding.cpu()
3536
return query_embedding
3637

37-
def get_doc_embeddings(self, use_cpu=False):
38+
def get_doc_embeddings(self, use_cpu=True):
39+
# Move the embeddings to the CPU to ensure compatibility with operations like cosine_similarity
3840
if use_cpu:
3941
return self.doc_embeddings.cpu()
4042
return self.doc_embeddings
@@ -62,12 +64,9 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_th
6264

6365
return result
6466

65-
def retrieve(self, query, similarity_threshold=0.7, high_match_threshold=0.8, max_docs=5, use_cpu=False):
66-
# Note: Set use_cpu=True to run on CPU, which is useful for testing or environments without a GPU.
67-
# Set use_cpu=False to leverage GPU for better performance in production.
68-
69-
query_embedding = self.get_query_embedding(query, use_cpu)
70-
doc_embeddings = self.get_doc_embeddings(use_cpu)
67+
def retrieve(self, query, similarity_threshold=0.7, high_match_threshold=0.8, max_docs=5):
68+
query_embedding = self.get_query_embedding(query)
69+
doc_embeddings = self.get_doc_embeddings()
7170

7271
doc_scores = self.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold)
7372
retrieved_docs = self.get_top_docs(doc_scores, similarity_threshold, max_docs)

app/test_rag_system.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ def test_get_doc_embeddings(self):
5353
def test_retrieve_fallback(self):
5454
# test a query that should return the fallback response
5555
query = "Hello"
56-
# set use_cpu to True, as testing has no GPU calculations
57-
result = self.rag_system.retrieve(query, use_cpu=True)
56+
result = self.rag_system.retrieve(query)
5857
self.assertIsInstance(result, list)
5958
self.assertGreater(len(result), 0)
6059
self.assertEqual(len(result), 1) # should return one result for fallback
@@ -67,8 +66,7 @@ def test_retrieve_fallback(self):
6766
def test_retrieve_actual_response(self):
6867
# test a query that should return an actual response from the knowledge base
6968
query = "What is Defang?"
70-
# set use_cpu to True, as testing has no GPU calculations
71-
result = self.rag_system.retrieve(query, use_cpu=True)
69+
result = self.rag_system.retrieve(query)
7270
self.assertIsInstance(result, list)
7371
self.assertGreater(len(result), 0)
7472
self.assertLessEqual(len(result), 5) # should return up to max_docs (5)
@@ -81,8 +79,8 @@ def test_retrieve_actual_response(self):
8179
def test_compute_document_scores(self):
8280
query = "Does Defang have an MCP sample?"
8381
# get embeddings and move them to CPU, as testing has no GPU calculations
84-
query_embedding = self.rag_system.get_query_embedding(query, use_cpu=True)
85-
doc_embeddings = self.rag_system.get_doc_embeddings(use_cpu=True)
82+
query_embedding = self.rag_system.get_query_embedding(query)
83+
doc_embeddings = self.rag_system.get_doc_embeddings()
8684

8785
# call function and get results
8886
result = self.rag_system.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold=0.8)

0 commit comments

Comments
 (0)