Skip to content

Commit 46eea00

Browse files
committed
Revert "make cpu tensors the default for both testing and production"
This reverts commit 020e1b3.
1 parent 63e607c commit 46eea00

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

app/rag_system.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,14 @@ def embed_knowledge_base(self):
2727
def normalize_query(self, query):
2828
return query.lower().strip()
2929

30-
def get_query_embedding(self, query, use_cpu=True):
30+
def get_query_embedding(self, query, use_cpu=False):
3131
normalized_query = self.normalize_query(query)
3232
query_embedding = self.model.encode([normalized_query], convert_to_tensor=True)
33-
# Move the embeddings to the CPU to ensure compatibility with operations like cosine_similarity
3433
if use_cpu:
3534
query_embedding = query_embedding.cpu()
3635
return query_embedding
3736

38-
def get_doc_embeddings(self, use_cpu=True):
39-
# Move the embeddings to the CPU to ensure compatibility with operations like cosine_similarity
37+
def get_doc_embeddings(self, use_cpu=False):
4038
if use_cpu:
4139
return self.doc_embeddings.cpu()
4240
return self.doc_embeddings
@@ -64,9 +62,12 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_th
6462

6563
return result
6664

67-
def retrieve(self, query, similarity_threshold=0.7, high_match_threshold=0.8, max_docs=5):
68-
query_embedding = self.get_query_embedding(query)
69-
doc_embeddings = self.get_doc_embeddings()
65+
def retrieve(self, query, similarity_threshold=0.7, high_match_threshold=0.8, max_docs=5, use_cpu=False):
66+
# Note: Set use_cpu=True to run on CPU, which is useful for testing or environments without a GPU.
67+
# Set use_cpu=False to leverage GPU for better performance in production.
68+
69+
query_embedding = self.get_query_embedding(query, use_cpu)
70+
doc_embeddings = self.get_doc_embeddings(use_cpu)
7071

7172
doc_scores = self.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold)
7273
retrieved_docs = self.get_top_docs(doc_scores, similarity_threshold, max_docs)

app/test_rag_system.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def test_get_doc_embeddings(self):
5353
def test_retrieve_fallback(self):
5454
# test a query that should return the fallback response
5555
query = "Hello"
56-
result = self.rag_system.retrieve(query)
56+
# set use_cpu to True, as testing has no GPU calculations
57+
result = self.rag_system.retrieve(query, use_cpu=True)
5758
self.assertIsInstance(result, list)
5859
self.assertGreater(len(result), 0)
5960
self.assertEqual(len(result), 1) # should return one result for fallback
@@ -66,7 +67,8 @@ def test_retrieve_fallback(self):
6667
def test_retrieve_actual_response(self):
6768
# test a query that should return an actual response from the knowledge base
6869
query = "What is Defang?"
69-
result = self.rag_system.retrieve(query)
70+
# set use_cpu to True, as testing has no GPU calculations
71+
result = self.rag_system.retrieve(query, use_cpu=True)
7072
self.assertIsInstance(result, list)
7173
self.assertGreater(len(result), 0)
7274
self.assertLessEqual(len(result), 5) # should return up to max_docs (5)
@@ -79,8 +81,8 @@ def test_retrieve_actual_response(self):
7981
def test_compute_document_scores(self):
8082
query = "Does Defang have an MCP sample?"
8183
# get embeddings and move them to CPU, as testing has no GPU calculations
82-
query_embedding = self.rag_system.get_query_embedding(query)
83-
doc_embeddings = self.rag_system.get_doc_embeddings()
84+
query_embedding = self.rag_system.get_query_embedding(query, use_cpu=True)
85+
doc_embeddings = self.rag_system.get_doc_embeddings(use_cpu=True)
8486

8587
# call function and get results
8688
result = self.rag_system.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold=0.8)

0 commit comments

Comments
 (0)