Skip to content

Commit c6917bd

Browse files
committed
Implemented get_random_neighbor and helper functions in Searchspace, which are much faster to find random neighbors than looking up all neighbors and selecting a random one
1 parent 214865f commit c6917bd

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed

kernel_tuner/searchspace.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,91 @@ def __get_neighbors_indices_hamming(self, param_config: tuple) -> List[int]:
826826
matching_indices = (num_matching_params == self.num_params - 1).nonzero()[0]
827827
return matching_indices
828828

829+
def __get_random_neighbor_hamming(self, param_config: tuple) -> tuple:
830+
"""Get a random neighbor at 1 Hamming distance from the parameter configuration."""
831+
arr = self.get_list_numpy()
832+
target = np.array(param_config)
833+
assert arr[0].shape == target.shape
834+
835+
# find the first row that differs from the target in exactly one column, return as soon as one is found
836+
random_order_indices = np.random.permutation(arr.shape[0])
837+
for i in random_order_indices:
838+
# assert arr[i].shape == target.shape, f"Row {i} shape {arr[i].shape} does not match target shape {target.shape}"
839+
if np.count_nonzero(arr[i] != target) == 1:
840+
return self.get_param_configs_at_indices([i])[0]
841+
return None
842+
843+
def __get_random_neighbor_adjacent(self, param_config: tuple) -> tuple:
844+
"""Get an approximately random adjacent neighbor of the parameter configuration."""
845+
# NOTE: this is not truly random as we only progressively increase the allowed index difference if no neighbors are found, but much faster than generating all neighbors
846+
847+
# get the indices of the parameter values
848+
if self.params_values_indices is None:
849+
self.__prepare_neighbors_index()
850+
param_config_index = self.get_param_config_index(param_config)
851+
param_config_value_indices = (
852+
self.get_param_indices(param_config)
853+
if param_config_index is None
854+
else self.params_values_indices[param_config_index]
855+
)
856+
max_index_difference_per_param = [max(len(self.params_values[p]) - 1 - i, i) for p, i in enumerate(param_config_value_indices)]
857+
858+
# calculate the absolute difference between the parameter value indices
859+
abs_index_difference = np.abs(self.params_values_indices - param_config_value_indices)
860+
861+
# calculate the difference between the parameter value indices
862+
index_difference = np.abs(self.params_values_indices - param_config_value_indices)
863+
# transpose to get the param indices difference per parameter instead of per param config
864+
index_difference_transposed = index_difference.transpose()
865+
866+
# start at an index difference of 1, progressively increase - potentially expensive if there are no neighbors until very late
867+
max_index_difference = max(max_index_difference_per_param)
868+
allowed_index_difference = 1
869+
allowed_values = [[v] for v in param_config]
870+
while allowed_index_difference <= max_index_difference:
871+
# get the param config indices where the difference is allowed_index_difference or less for each position
872+
matching_indices = (np.max(abs_index_difference, axis=1) <= allowed_index_difference).nonzero()[0]
873+
# as the selected param config does not differ anywhere, remove it from the matches
874+
if param_config_index is not None:
875+
matching_indices = np.setdiff1d(matching_indices, [param_config_index], assume_unique=False)
876+
877+
# if there are matching indices, return a random one
878+
if len(matching_indices) > 0:
879+
# get the random index from the matching indices
880+
random_neighbor_index = np.random.choice(matching_indices)
881+
return self.get_param_configs_at_indices([random_neighbor_index])[0]
882+
883+
# if there are no matching indices, increase the allowed index difference and start over
884+
allowed_index_difference += 1
885+
return None
886+
887+
# alternative implementation
888+
# # start at an index difference of 1, progressively increase - potentially expensive if there are no neighbors
889+
# allowed_index_difference = 1
890+
# allowed_values = [[v] for v in param_config]
891+
# while evaluated_configs < self.size:
892+
# # for each parameter, add the allowed values
893+
# for i, value in enumerate(param_config):
894+
# param_values = self.tune_params[i]
895+
# current_index = param_values.index(value)
896+
897+
# # add lower neighbor (if exists)
898+
# if current_index - allowed_index_difference >= 0:
899+
# allowed_values[i].append(param_values[current_index - allowed_index_difference])
900+
# neighbor_candidates.append(tuple(lower_neighbor))
901+
902+
# # add upper neighbor (if exists)
903+
# if current_index + allowed_index_difference < len(param_values):
904+
# allowed_values[i].append(param_values[current_index + allowed_index_difference])
905+
906+
# # create the random list of candidate neighbors (Cartesian product of allowed values)
907+
# from itertools import product
908+
# candidate_neighbors = product(*allowed_values)
909+
# for candidate in candidate_neighbors:
910+
# # check if the candidate has not been previously evaluated
911+
# # check if the candidate neighbors are valid
912+
# return None
913+
829914
def __get_neighbors_indices_strictlyadjacent(
830915
self, param_config_index: int = None, param_config: tuple = None
831916
) -> List[int]:
@@ -982,6 +1067,28 @@ def get_neighbors(self, param_config: tuple, neighbor_method=None, build_full_ca
9821067
"""Get the neighbors for a parameter configuration."""
9831068
return self.get_param_configs_at_indices(self.get_neighbors_indices(param_config, neighbor_method, build_full_cache))
9841069

1070+
def get_random_neighbor(self, param_config: tuple, neighbor_method=None) -> tuple:
1071+
"""Get an approximately random neighbor for a parameter configuration. Much faster than taking a random choice of all neighbors, but does not build cache."""
1072+
if self.are_neighbors_indices_cached(param_config, neighbor_method):
1073+
neighbors = self.get_neighbors(param_config, neighbor_method)
1074+
return choice(neighbors)
1075+
else:
1076+
# check if there is a neighbor method to use
1077+
if neighbor_method is None:
1078+
neighbor_method = self.neighbor_method
1079+
1080+
# find the random neighbor based on the method
1081+
if neighbor_method == "Hamming":
1082+
return self.__get_random_neighbor_hamming(param_config)
1083+
elif neighbor_method == "adjacent":
1084+
return self.__get_random_neighbor_adjacent(param_config)
1085+
else:
1086+
# not much performance to be gained for strictly-adjacent neighbors, just generate the neighbors
1087+
neighbors = self.get_neighbors(param_config, neighbor_method)
1088+
if len(neighbors) == 0:
1089+
return None
1090+
return choice(neighbors)
1091+
9851092
def get_param_neighbors(self, param_config: tuple, index: int, neighbor_method: str, randomize: bool) -> list:
9861093
"""Get the neighboring parameters at an index."""
9871094
original_value = param_config[index]

0 commit comments

Comments
 (0)