Skip to content

Commit ae486f6

Browse files
Remove use_cpu args
We now always run on CPU, so these arguments are irrelevant.
1 parent 213a3d9 commit ae486f6

File tree

2 files changed

+12
-21
lines changed

2 files changed

+12
-21
lines changed

app/rag_system.py

Lines changed: 7 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -30,17 +30,14 @@ def embed_knowledge_base(self):
3030
def normalize_query(self, query):
3131
return query.lower().strip()
3232

33-
def get_query_embedding(self, query, use_cpu=True):
33+
def get_query_embedding(self, query):
3434
normalized_query = self.normalize_query(query)
3535
query_embedding = self.model.encode([normalized_query], convert_to_tensor=True)
36-
if use_cpu:
37-
query_embedding = query_embedding.cpu()
36+
query_embedding = query_embedding.cpu()
3837
return query_embedding
3938

40-
def get_doc_embeddings(self, use_cpu=True):
41-
if use_cpu:
42-
return self.doc_embeddings.cpu()
43-
return self.doc_embeddings
39+
def get_doc_embeddings(self):
40+
return self.doc_embeddings.cpu()
4441

4542
def compute_document_scores(self, query_embedding, doc_embeddings, high_match_threshold):
4643
text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
@@ -66,12 +63,9 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_th
6663

6764
return result
6865

69-
def retrieve(self, query, similarity_threshold=0.4, high_match_threshold=0.8, max_docs=5, use_cpu=True):
70-
# Note: Set use_cpu=True to run on CPU, which is useful for testing or environments without a GPU.
71-
# Set use_cpu=False to leverage GPU for better performance in production.
72-
73-
query_embedding = self.get_query_embedding(query, use_cpu)
74-
doc_embeddings = self.get_doc_embeddings(use_cpu)
66+
def retrieve(self, query, similarity_threshold=0.4, high_match_threshold=0.8, max_docs=5):
67+
query_embedding = self.get_query_embedding(query)
68+
doc_embeddings = self.get_doc_embeddings()
7569

7670
doc_scores = self.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold)
7771
retrieved_docs = self.get_top_docs(doc_scores, similarity_threshold, max_docs)

app/test_rag_system.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ def test_get_doc_embeddings(self):
5353
def test_retrieve_fallback(self):
5454
# test a query that should return the fallback response
5555
query = "Hello"
56-
# set use_cpu to True, as testing has no GPU calculations
57-
result = self.rag_system.retrieve(query, use_cpu=True)
56+
result = self.rag_system.retrieve(query)
5857
self.assertIsInstance(result, list)
5958
self.assertGreater(len(result), 0)
6059
self.assertEqual(len(result), 1) # should return one result for fallback
@@ -67,8 +66,7 @@ def test_retrieve_fallback(self):
6766
def test_retrieve_actual_response(self):
6867
# test a query that should return an actual response from the knowledge base
6968
query = "What is Defang?"
70-
# set use_cpu to True, as testing has no GPU calculations
71-
result = self.rag_system.retrieve(query, use_cpu=True)
69+
result = self.rag_system.retrieve(query)
7270
self.assertIsInstance(result, list)
7371
self.assertGreater(len(result), 0)
7472
self.assertLessEqual(len(result), 5) # should return up to max_docs (5)
@@ -80,9 +78,8 @@ def test_retrieve_actual_response(self):
8078

8179
def test_compute_document_scores(self):
8280
query = "Does Defang have an MCP sample?"
83-
# get embeddings and move them to CPU, as testing has no GPU calculations
84-
query_embedding = self.rag_system.get_query_embedding(query, use_cpu=True)
85-
doc_embeddings = self.rag_system.get_doc_embeddings(use_cpu=True)
81+
query_embedding = self.rag_system.get_query_embedding(query)
82+
doc_embeddings = self.rag_system.get_doc_embeddings()
8683

8784
# call function and get results
8885
result = self.rag_system.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold=0.8)
@@ -105,4 +102,4 @@ def test_compute_document_scores(self):
105102
print("Test for compute_document_scores passed successfully!")
106103

107104
if __name__ == '__main__':
108-
unittest.main()
105+
unittest.main()

0 commit comments

Comments
 (0)