@@ -52,7 +52,6 @@ def min_func(gpu1, gpu2, gpu3, gpu4):
5252sort_tune_params ["gpu1" ] = list (range (num_layers ))
5353sort_tune_params ["gpu2" ] = list (range (num_layers ))
5454sort_tune_params ["gpu3" ] = list (range (num_layers ))
55- sort_tuning_options = Options (dict (restrictions = [], tune_params = sort_tune_params ))
5655searchspace_sort = Searchspace (sort_tune_params , [], max_threads )
5756
5857def test_size ():
@@ -73,8 +72,13 @@ def test_internal_representation():
7372
7473def test_sort ():
7574 """test that the sort searchspace option works as expected"""
76- simple_searchspace_sort = Searchspace (simple_tuning_options .tune_params , simple_tuning_options .restrictions , max_threads , sort = True , sort_last_param_first = False )
77- assert simple_searchspace_sort .list == [
75+ simple_searchspace_sort = Searchspace (
76+ simple_tuning_options .tune_params ,
77+ simple_tuning_options .restrictions ,
78+ max_threads
79+ )
80+
81+ expected = [
7882 (1 , 4 , "string_1" ),
7983 (1 , 4 , "string_2" ),
8084 (1 , 5.5 , "string_1" ),
@@ -89,7 +93,11 @@ def test_sort():
8993 (3 , 5.5 , "string_2" ),
9094 ]
9195
92- assert simple_searchspace .sorted_list (sort_last_param_first = False ) == expected
96+ # Check if lists match without considering order
97+ assert set (simple_searchspace_sort .list ) == set (expected )
98+
99+ # Check if lists match, also considering order
100+ assert simple_searchspace_sort .sorted_list () == expected
93101
94102 sorted_list = searchspace_sort .sorted_list (sort_last_param_first = False )
95103 num_params = len (sorted_list [0 ])
@@ -102,8 +110,13 @@ def test_sort():
102110
103111def test_sort_reversed ():
104112 """test that the sort searchspace option with the sort_last_param_first option enabled works as expected"""
105- simple_searchspace_sort_reversed = Searchspace (simple_tuning_options , max_threads , sort = True , sort_last_param_first = True )
106- assert simple_searchspace_sort_reversed .list == [
113+ simple_searchspace_sort_reversed = Searchspace (
114+ simple_tuning_options .tune_params ,
115+ simple_tuning_options .restrictions ,
116+ max_threads
117+ )
118+
119+ expected = [
107120 (1 , 4 , "string_1" ),
108121 (2 , 4 , "string_1" ),
109122 (3 , 4 , "string_1" ),
@@ -118,6 +131,10 @@ def test_sort_reversed():
118131 (3 , 5.5 , "string_2" ),
119132 ]
120133
134+ # Check if lists match without considering order
135+ assert set (simple_searchspace_sort_reversed .list ) == set (expected )
136+
137+ # Check if lists match, also considering order
121138 assert simple_searchspace_sort_reversed .sorted_list (sort_last_param_first = True ) == expected
122139
123140 sorted_list = searchspace_sort .sorted_list (sort_last_param_first = True )
@@ -178,7 +195,8 @@ def test_random_sample():
178195
179196def __test_neighbors_prebuilt (param_config : tuple , expected_neighbors : list , neighbor_method : str ):
180197 simple_searchspace_prebuilt = Searchspace (
181- simple_tuning_options ,
198+ simple_tuning_options .tune_params ,
199+ simple_tuning_options .restrictions ,
182200 max_threads ,
183201 build_neighbors_index = True ,
184202 neighbor_method = neighbor_method ,
@@ -271,7 +289,13 @@ def test_neighbors_fictious():
271289
272290def test_neighbors_cached ():
273291 """test whether retrieving a set of neighbors twice returns the cached version"""
274- simple_searchspace_duplicate = Searchspace (simple_tuning_options , max_threads , neighbor_method = "Hamming" )
292+ simple_searchspace_duplicate = Searchspace (
293+ simple_tuning_options .tune_params ,
294+ simple_tuning_options .restrictions ,
295+ max_threads ,
296+ neighbor_method = "Hamming"
297+ )
298+
275299 test_configs = simple_searchspace_duplicate .get_random_sample (10 )
276300 for test_config in test_configs :
277301 assert not simple_searchspace_duplicate .are_neighbors_indices_cached (test_config )
0 commit comments