Skip to content

Commit 8954a46

Browse files
committed
Implemented building multiple neighbor index caches, optional parameter to build full cache
1 parent 5ec2d60 commit 8954a46

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

kernel_tuner/searchspace.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
framework_l = framework.lower()
8080
restrictions = restrictions if restrictions is not None else []
8181
self.tune_params = tune_params
82+
self.original_tune_params = tune_params.copy() if hasattr(tune_params, "copy") else tune_params
8283
self.max_threads = max_threads
8384
self.block_size_names = block_size_names
8485
self._tensorspace = None
@@ -92,6 +93,7 @@ def __init__(
9293
self._map_tensor_to_param = {}
9394
self._map_param_to_tensor = {}
9495
self.restrictions = restrictions.copy() if hasattr(restrictions, "copy") else restrictions
96+
self.original_restrictions = restrictions.copy() if hasattr(restrictions, "copy") else restrictions
9597
# the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads)
9698
self._modified_restrictions = restrictions.copy() if hasattr(restrictions, "copy") else restrictions
9799
self.param_names = list(self.tune_params.keys())
@@ -100,6 +102,7 @@ def __init__(
100102
self.build_neighbors_index = build_neighbors_index
101103
self.solver_method = solver_method
102104
self.__neighbor_cache = { method: dict() for method in supported_neighbor_methods }
105+
self.neighbors_index = dict()
103106
self.neighbor_method = neighbor_method
104107
if (neighbor_method is not None or build_neighbors_index) and neighbor_method not in supported_neighbor_methods:
105108
raise ValueError(f"Neighbor method is {neighbor_method}, must be one of {supported_neighbor_methods}")
@@ -175,7 +178,7 @@ def __init__(
175178
if neighbor_method is not None and neighbor_method != "Hamming":
176179
self.__prepare_neighbors_index()
177180
if build_neighbors_index:
178-
self.neighbors_index = self.__build_neighbors_index(neighbor_method)
181+
self.neighbors_index[neighbor_method] = self.__build_neighbors_index(neighbor_method)
179182

180183
# def __build_searchspace_ortools(self, block_size_names: list, max_threads: int) -> Tuple[List[tuple], np.ndarray, dict, int]:
181184
# # Based on https://developers.google.com/optimization/cp/cp_solver#python_2
@@ -452,7 +455,8 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
452455
# add the user-specified restrictions as constraints on the parameter space
453456
if not isinstance(self.restrictions, (list, tuple)):
454457
self.restrictions = [self.restrictions]
455-
self.restrictions = convert_constraint_lambdas(self.restrictions)
458+
if any(not isinstance(restriction, (Constraint, FunctionConstraint, str)) for restriction in self.restrictions):
459+
self.restrictions = convert_constraint_lambdas(self.restrictions)
456460
parameter_space = self.__add_restrictions(parameter_space)
457461

458462
# add the default blocksize threads restrictions last, because it is unlikely to reduce the parameter space by much
@@ -901,24 +905,25 @@ def get_random_sample(self, num_samples: int) -> List[tuple]:
901905
num_samples = self.size
902906
return self.get_param_configs_at_indices(self.get_random_sample_indices(num_samples))
903907

904-
def get_neighbors_indices_no_cache(self, param_config: tuple, neighbor_method=None) -> List[int]:
908+
def get_neighbors_indices_no_cache(self, param_config: tuple, neighbor_method=None, build_full_cache=False) -> List[int]:
905909
"""Get the neighbors indices for a parameter configuration (does not check running cache, useful when mixing neighbor methods)."""
906910
param_config_index = self.get_param_config_index(param_config)
907911

908-
# this is the simplest case, just return the cached value
909-
if self.build_neighbors_index and param_config_index is not None:
910-
if neighbor_method is not None and neighbor_method != self.neighbor_method:
911-
raise ValueError(
912-
f"The neighbor method {neighbor_method} differs from the neighbor method {self.neighbor_method} initially used for indexing"
913-
)
914-
return self.neighbors_index[param_config_index]
915-
916912
# check if there is a neighbor method to use
917913
if neighbor_method is None:
918914
if self.neighbor_method is None:
919915
raise ValueError("Neither the neighbor_method argument nor self.neighbor_method was set")
920916
neighbor_method = self.neighbor_method
921917

918+
# this is the simplest case, just return the cached value
919+
if param_config_index is not None:
920+
if neighbor_method in self.neighbors_index:
921+
return self.neighbors_index[neighbor_method][param_config_index]
922+
elif build_full_cache:
923+
# build the neighbors index for the given neighbor method
924+
self.neighbors_index[neighbor_method] = self.__build_neighbors_index(neighbor_method)
925+
return self.neighbors_index[neighbor_method][param_config_index]
926+
922927
if neighbor_method == "Hamming":
923928
return self.__get_neighbors_indices_hamming(param_config)
924929

@@ -933,7 +938,7 @@ def get_neighbors_indices_no_cache(self, param_config: tuple, neighbor_method=No
933938
return self.__get_neighbors_indices_adjacent(param_config_index, param_config)
934939
raise ValueError(f"The neighbor method {neighbor_method} is not in {supported_neighbor_methods}")
935940

936-
def get_neighbors_indices(self, param_config: tuple, neighbor_method=None) -> List[int]:
941+
def get_neighbors_indices(self, param_config: tuple, neighbor_method=None, build_full_cache=False) -> List[int]:
937942
"""Get the neighbors indices for a parameter configuration, cached if requested before."""
938943
if neighbor_method is None:
939944
neighbor_method = self.neighbor_method
@@ -942,7 +947,7 @@ def get_neighbors_indices(self, param_config: tuple, neighbor_method=None) -> Li
942947
neighbors = self.__neighbor_cache[neighbor_method].get(param_config, None)
943948
# if there are no cached neighbors, compute them
944949
if neighbors is None:
945-
neighbors = self.get_neighbors_indices_no_cache(param_config, neighbor_method)
950+
neighbors = self.get_neighbors_indices_no_cache(param_config, neighbor_method, build_full_cache)
946951
self.__neighbor_cache[neighbor_method][param_config] = neighbors
947952
return neighbors
948953

@@ -958,9 +963,9 @@ def get_neighbors_no_cache(self, param_config: tuple, neighbor_method=None) -> L
958963
"""Get the neighbors for a parameter configuration (does not check running cache, useful when mixing neighbor methods)."""
959964
return self.get_param_configs_at_indices(self.get_neighbors_indices_no_cache(param_config, neighbor_method))
960965

961-
def get_neighbors(self, param_config: tuple, neighbor_method=None) -> List[tuple]:
966+
def get_neighbors(self, param_config: tuple, neighbor_method=None, build_full_cache=False) -> List[tuple]:
962967
"""Get the neighbors for a parameter configuration."""
963-
return self.get_param_configs_at_indices(self.get_neighbors_indices(param_config, neighbor_method))
968+
return self.get_param_configs_at_indices(self.get_neighbors_indices(param_config, neighbor_method, build_full_cache))
964969

965970
def get_param_neighbors(self, param_config: tuple, index: int, neighbor_method: str, randomize: bool) -> list:
966971
"""Get the neighboring parameters at an index."""

0 commit comments

Comments
 (0)