Skip to content

Commit 0613c01

Browse files
committed
Fix tests to work with new constructor for Searchspace
1 parent e24c4cb commit 0613c01

File tree

5 files changed

+40
-15
lines changed

5 files changed

+40
-15
lines changed

kernel_tuner/searchspace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(
4747
if neighbor_method is not None and neighbor_method != "Hamming":
4848
self.__prepare_neighbors_index()
4949
if build_neighbors_index:
50-
self.neighbors_index = self.__build_neighbors_index(neighbor_method, max_threads)
50+
self.neighbors_index = self.__build_neighbors_index(neighbor_method)
5151

5252
def __build_searchspace(self, block_size_names: list, max_threads: int) -> Tuple[List[tuple], np.ndarray, dict, int]:
5353
"""compute valid configurations in a search space based on restrictions and max_threads, returns the searchspace, a dict of the searchspace for fast lookups and the size"""

test/strategies/test_common.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@ def fake_runner():
2626
tune_params = OrderedDict([("x", [1, 2, 3]), ("y", [4, 5, 6])])
2727

2828

29-
def test__cost_func():
29+
def test_cost_func():
3030
x = [1, 4]
31-
kernel_options = None
3231
tuning_options = Options(scaling=False, snap=False, tune_params=tune_params,
3332
restrictions=None, strategy_options={}, cache={}, unique_results={},
3433
objective="time", objective_higher_is_better=False, metrics=None)
@@ -39,12 +38,12 @@ def test__cost_func():
3938
assert time == 5
4039

4140
# check if restrictions are properly handled
42-
restrictions = ["False"]
41+
restrictions = lambda _: False
4342
tuning_options = Options(scaling=False, snap=False, tune_params=tune_params,
4443
restrictions=restrictions, strategy_options={},
4544
verbose=True, cache={}, unique_results={},
4645
objective="time", objective_higher_is_better=False, metrics=None)
47-
time = CostFunc(Searchspace(tune_params, None, 1024), tuning_options, runner)(x)
46+
time = CostFunc(Searchspace(tune_params, restrictions, 1024), tuning_options, runner)(x)
4847
assert time == sys.float_info.max
4948

5049

test/test_common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def test_get_bounds():
3939
random.shuffle(tune_params[k])
4040

4141
expected = [(0, 4), (0, 9900), (-11.2, 123.27)]
42-
answer = common.get_bounds(tune_params)
42+
searchspace = Searchspace(tune_params, None, 1024)
43+
cost_func = common.CostFunc(searchspace, None, None)
44+
answer = cost_func.get_bounds()
4345
assert answer == expected
4446

4547

test/test_runners.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def test_runner(env):
290290
[32]) # vector_add only has one tunable parameter (block_size_x)
291291

292292
# call the runner
293-
results, _ = runner.run(searchspace, kernel_options, tuning_options)
293+
results = runner.run(searchspace, tuning_options)
294294

295295
assert len(results) == 1
296296
assert results[0]['block_size_x'] == 32

test/test_searchspace.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def min_func(gpu1, gpu2, gpu3, gpu4):
5252
sort_tune_params["gpu1"] = list(range(num_layers))
5353
sort_tune_params["gpu2"] = list(range(num_layers))
5454
sort_tune_params["gpu3"] = list(range(num_layers))
55-
sort_tuning_options = Options(dict(restrictions=[], tune_params=sort_tune_params))
5655
searchspace_sort = Searchspace(sort_tune_params, [], max_threads)
5756

5857
def test_size():
@@ -73,8 +72,13 @@ def test_internal_representation():
7372

7473
def test_sort():
7574
"""test that the sort searchspace option works as expected"""
76-
simple_searchspace_sort = Searchspace(simple_tuning_options.tune_params, simple_tuning_options.restrictions, max_threads, sort=True, sort_last_param_first=False)
77-
assert simple_searchspace_sort.list == [
75+
simple_searchspace_sort = Searchspace(
76+
simple_tuning_options.tune_params,
77+
simple_tuning_options.restrictions,
78+
max_threads
79+
)
80+
81+
expected = [
7882
(1, 4, "string_1"),
7983
(1, 4, "string_2"),
8084
(1, 5.5, "string_1"),
@@ -89,7 +93,11 @@ def test_sort():
8993
(3, 5.5, "string_2"),
9094
]
9195

92-
assert simple_searchspace.sorted_list(sort_last_param_first=False) == expected
96+
# Check if lists match without considering order
97+
assert set(simple_searchspace_sort.list) == set(expected)
98+
99+
# Check if lists match, also considering order
100+
assert simple_searchspace_sort.sorted_list() == expected
93101

94102
sorted_list = searchspace_sort.sorted_list(sort_last_param_first=False)
95103
num_params = len(sorted_list[0])
@@ -102,8 +110,13 @@ def test_sort():
102110

103111
def test_sort_reversed():
104112
"""test that the sort searchspace option with the sort_last_param_first option enabled works as expected"""
105-
simple_searchspace_sort_reversed = Searchspace(simple_tuning_options, max_threads, sort=True, sort_last_param_first=True)
106-
assert simple_searchspace_sort_reversed.list == [
113+
simple_searchspace_sort_reversed = Searchspace(
114+
simple_tuning_options.tune_params,
115+
simple_tuning_options.restrictions,
116+
max_threads
117+
)
118+
119+
expected = [
107120
(1, 4, "string_1"),
108121
(2, 4, "string_1"),
109122
(3, 4, "string_1"),
@@ -118,6 +131,10 @@ def test_sort_reversed():
118131
(3, 5.5, "string_2"),
119132
]
120133

134+
# Check if lists match without considering order
135+
assert set(simple_searchspace_sort_reversed.list) == set(expected)
136+
137+
# Check if lists match, also considering order
121138
assert simple_searchspace_sort_reversed.sorted_list(sort_last_param_first=True) == expected
122139

123140
sorted_list = searchspace_sort.sorted_list(sort_last_param_first=True)
@@ -178,7 +195,8 @@ def test_random_sample():
178195

179196
def __test_neighbors_prebuilt(param_config: tuple, expected_neighbors: list, neighbor_method: str):
180197
simple_searchspace_prebuilt = Searchspace(
181-
simple_tuning_options,
198+
simple_tuning_options.tune_params,
199+
simple_tuning_options.restrictions,
182200
max_threads,
183201
build_neighbors_index=True,
184202
neighbor_method=neighbor_method,
@@ -271,7 +289,13 @@ def test_neighbors_fictious():
271289

272290
def test_neighbors_cached():
273291
"""test whether retrieving a set of neighbors twice returns the cached version"""
274-
simple_searchspace_duplicate = Searchspace(simple_tuning_options, max_threads, neighbor_method="Hamming")
292+
simple_searchspace_duplicate = Searchspace(
293+
simple_tuning_options.tune_params,
294+
simple_tuning_options.restrictions,
295+
max_threads,
296+
neighbor_method="Hamming"
297+
)
298+
275299
test_configs = simple_searchspace_duplicate.get_random_sample(10)
276300
for test_config in test_configs:
277301
assert not simple_searchspace_duplicate.are_neighbors_indices_cached(test_config)

0 commit comments

Comments
 (0)