Skip to content

Commit 53d2ee2

Browse files
committed
Fix constructor of Searchspace that was lost due to bad merge
1 parent 58f9faa commit 53d2ee2

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

kernel_tuner/searchspace.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ class Searchspace:
1515
"""Class that offers the search space to strategies"""
1616

1717
def __init__(
18-
self,
19-
tuning_options: dict,
20-
max_threads: int,
21-
build_neighbors_index=False,
22-
neighbor_method=None,
23-
sort=False,
24-
sort_last_param_first=False,
18+
self,
19+
tune_params: dict,
20+
restrictions,
21+
max_threads: int,
22+
block_size_names=default_block_size_names,
23+
build_neighbors_index=False,
24+
neighbor_method=None
2525
) -> None:
2626
"""Build a searchspace using the variables and constraints.
2727
Optionally build the neighbors index - only faster if you repeatedly look up neighbors. Methods:
@@ -47,9 +47,9 @@ def __init__(
4747
if neighbor_method is not None and neighbor_method != "Hamming":
4848
self.__prepare_neighbors_index()
4949
if build_neighbors_index:
50-
self.neighbors_index = self.__build_neighbors_index(neighbor_method)
50+
self.neighbors_index = self.__build_neighbors_index(neighbor_method, max_threads)
5151

52-
def __build_searchspace(self, sort: bool, sort_last_param_first: bool) -> Tuple[List[tuple], np.ndarray, dict, int]:
52+
def __build_searchspace(self, block_size_names: list, max_threads: int) -> Tuple[List[tuple], np.ndarray, dict, int]:
5353
"""compute valid configurations in a search space based on restrictions and max_threads, returns the searchspace, a dict of the searchspace for fast lookups and the size"""
5454

5555
# instantiate the parameter space with all the variables

test/test_searchspace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def min_func(gpu1, gpu2, gpu3, gpu4):
5353
sort_tune_params["gpu2"] = list(range(num_layers))
5454
sort_tune_params["gpu3"] = list(range(num_layers))
5555
sort_tuning_options = Options(dict(restrictions=[], tune_params=sort_tune_params))
56-
searchspace_sort = Searchspace(sort_tuning_options.tune_params, sort_tuning_options.restrictions, max_threads)
56+
searchspace_sort = Searchspace(sort_tune_params, [], max_threads)
5757

5858
def test_size():
5959
"""test that the searchspace after applying restrictions is the expected size"""
@@ -73,7 +73,7 @@ def test_internal_representation():
7373

7474
def test_sort():
7575
"""test that the sort searchspace option works as expected"""
76-
simple_searchspace_sort = Searchspace(simple_tuning_options, max_threads, sort=True, sort_last_param_first=False)
76+
simple_searchspace_sort = Searchspace(simple_tuning_options.tune_params, simple_tuning_options.restrictions, max_threads, sort=True, sort_last_param_first=False)
7777
assert simple_searchspace_sort.list == [
7878
(1, 4, "string_1"),
7979
(1, 4, "string_2"),

0 commit comments

Comments
 (0)