@@ -53,8 +53,7 @@ def test_get_doc_embeddings(self):
53
53
def test_retrieve_fallback (self ):
54
54
# test a query that should return the fallback response
55
55
query = "Hello"
56
- # set use_cpu to True, as testing has no GPU calculations
57
- result = self .rag_system .retrieve (query , use_cpu = True )
56
+ result = self .rag_system .retrieve (query )
58
57
self .assertIsInstance (result , list )
59
58
self .assertGreater (len (result ), 0 )
60
59
self .assertEqual (len (result ), 1 ) # should return one result for fallback
@@ -67,8 +66,7 @@ def test_retrieve_fallback(self):
67
66
def test_retrieve_actual_response (self ):
68
67
# test a query that should return an actual response from the knowledge base
69
68
query = "What is Defang?"
70
- # set use_cpu to True, as testing has no GPU calculations
71
- result = self .rag_system .retrieve (query , use_cpu = True )
69
+ result = self .rag_system .retrieve (query )
72
70
self .assertIsInstance (result , list )
73
71
self .assertGreater (len (result ), 0 )
74
72
self .assertLessEqual (len (result ), 5 ) # should return up to max_docs (5)
@@ -80,9 +78,8 @@ def test_retrieve_actual_response(self):
80
78
81
79
def test_compute_document_scores (self ):
82
80
query = "Does Defang have an MCP sample?"
83
- # get embeddings and move them to CPU, as testing has no GPU calculations
84
- query_embedding = self .rag_system .get_query_embedding (query , use_cpu = True )
85
- doc_embeddings = self .rag_system .get_doc_embeddings (use_cpu = True )
81
+ query_embedding = self .rag_system .get_query_embedding (query )
82
+ doc_embeddings = self .rag_system .get_doc_embeddings ()
86
83
87
84
# call function and get results
88
85
result = self .rag_system .compute_document_scores (query_embedding , doc_embeddings , high_match_threshold = 0.8 )
@@ -105,4 +102,4 @@ def test_compute_document_scores(self):
105
102
print ("Test for compute_document_scores passed successfully!" )
106
103
107
104
if __name__ == '__main__' :
108
- unittest .main ()
105
+ unittest .main ()
0 commit comments