@@ -53,8 +53,7 @@ def test_get_doc_embeddings(self):
5353 def test_retrieve_fallback (self ):
5454 # test a query that should return the fallback response
5555 query = "Hello"
56- # set use_cpu to True, as testing has no GPU calculations
57- result = self .rag_system .retrieve (query , use_cpu = True )
56+ result = self .rag_system .retrieve (query )
5857 self .assertIsInstance (result , list )
5958 self .assertGreater (len (result ), 0 )
6059 self .assertEqual (len (result ), 1 ) # should return one result for fallback
@@ -67,8 +66,7 @@ def test_retrieve_fallback(self):
6766 def test_retrieve_actual_response (self ):
6867 # test a query that should return an actual response from the knowledge base
6968 query = "What is Defang?"
70- # set use_cpu to True, as testing has no GPU calculations
71- result = self .rag_system .retrieve (query , use_cpu = True )
69+ result = self .rag_system .retrieve (query )
7270 self .assertIsInstance (result , list )
7371 self .assertGreater (len (result ), 0 )
7472 self .assertLessEqual (len (result ), 5 ) # should return up to max_docs (5)
@@ -80,9 +78,8 @@ def test_retrieve_actual_response(self):
8078
8179 def test_compute_document_scores (self ):
8280 query = "Does Defang have an MCP sample?"
83- # get embeddings and move them to CPU, as testing has no GPU calculations
84- query_embedding = self .rag_system .get_query_embedding (query , use_cpu = True )
85- doc_embeddings = self .rag_system .get_doc_embeddings (use_cpu = True )
81+ query_embedding = self .rag_system .get_query_embedding (query )
82+ doc_embeddings = self .rag_system .get_doc_embeddings ()
8683
8784 # call function and get results
8885 result = self .rag_system .compute_document_scores (query_embedding , doc_embeddings , high_match_threshold = 0.8 )
@@ -105,4 +102,4 @@ def test_compute_document_scores(self):
105102 print ("Test for compute_document_scores passed successfully!" )
106103
107104if __name__ == '__main__' :
108- unittest .main ()
105+ unittest .main ()
0 commit comments