Skip to content

Commit 7fcfacb

Browse files
committed
Various improvements to random neighbor performance
1 parent d1d653e commit 7fcfacb

File tree

2 files changed

+17
-39
lines changed

2 files changed

+17
-39
lines changed

kernel_tuner/searchspace.py

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -868,49 +868,22 @@ def __get_random_neighbor_adjacent(self, param_config: tuple) -> tuple:
868868
allowed_index_difference = 1
869869
allowed_values = [[v] for v in param_config]
870870
while allowed_index_difference <= max_index_difference:
871-
# get the param config indices where the difference is allowed_index_difference or less for each position
872-
matching_indices = (np.max(abs_index_difference, axis=1) <= allowed_index_difference).nonzero()[0]
871+
# get the param config indices where the difference is at most allowed_index_difference for each position
872+
matching_indices = list((np.max(abs_index_difference, axis=1) <= allowed_index_difference).nonzero()[0])
873873
# as the selected param config does not differ anywhere, remove it from the matches
874874
if param_config_index is not None:
875-
matching_indices = np.setdiff1d(matching_indices, [param_config_index], assume_unique=False)
875+
matching_indices.remove(param_config_index)
876876

877877
# if there are matching indices, return a random one
878878
if len(matching_indices) > 0:
879879
# get the random index from the matching indices
880-
random_neighbor_index = np.random.choice(matching_indices)
880+
random_neighbor_index = choice(matching_indices)
881881
return self.get_param_configs_at_indices([random_neighbor_index])[0]
882882

883883
# if there are no matching indices, increase the allowed index difference and start over
884884
allowed_index_difference += 1
885885
return None
886886

887-
# alternative implementation
888-
# # start at an index difference of 1, progressively increase - potentially expensive if there are no neighbors
889-
# allowed_index_difference = 1
890-
# allowed_values = [[v] for v in param_config]
891-
# while evaluated_configs < self.size:
892-
# # for each parameter, add the allowed values
893-
# for i, value in enumerate(param_config):
894-
# param_values = self.tune_params[i]
895-
# current_index = param_values.index(value)
896-
897-
# # add lower neighbor (if exists)
898-
# if current_index - allowed_index_difference >= 0:
899-
# allowed_values[i].append(param_values[current_index - allowed_index_difference])
900-
# neighbor_candidates.append(tuple(lower_neighbor))
901-
902-
# # add upper neighbor (if exists)
903-
# if current_index + allowed_index_difference < len(param_values):
904-
# allowed_values[i].append(param_values[current_index + allowed_index_difference])
905-
906-
# # create the random list of candidate neighbors (Cartesian product of allowed values)
907-
# from itertools import product
908-
# candidate_neighbors = product(*allowed_values)
909-
# for candidate in candidate_neighbors:
910-
# # check if the candidate has not been previously evaluated
911-
# # check if the candidate neighbors are valid
912-
# return None
913-
914887
def __get_neighbors_indices_strictlyadjacent(
915888
self, param_config_index: int = None, param_config: tuple = None
916889
) -> List[int]:
@@ -926,7 +899,7 @@ def __get_neighbors_indices_strictlyadjacent(
926899
matching_indices = (np.max(abs_index_difference, axis=1) <= 1).nonzero()[0]
927900
# as the selected param config does not differ anywhere, remove it from the matches
928901
if param_config_index is not None:
929-
matching_indices = np.setdiff1d(matching_indices, [param_config_index], assume_unique=False)
902+
matching_indices = np.setdiff1d(matching_indices, [param_config_index], assume_unique=True)
930903
return matching_indices
931904

932905
def __get_neighbors_indices_adjacent(self, param_config_index: int = None, param_config: tuple = None) -> List[int]:
@@ -962,7 +935,7 @@ def __get_neighbors_indices_adjacent(self, param_config_index: int = None, param
962935
)
963936
# as the selected param config does not differ anywhere, remove it from the matches
964937
if param_config_index is not None:
965-
matching_indices = np.setdiff1d(matching_indices, [param_config_index], assume_unique=False)
938+
matching_indices = np.setdiff1d(matching_indices, [param_config_index], assume_unique=True)
966939
return matching_indices
967940

968941
def __build_neighbors_index(self, neighbor_method) -> List[List[int]]:
@@ -1078,10 +1051,11 @@ def get_random_neighbor(self, param_config: tuple, neighbor_method=None) -> tupl
10781051
neighbor_method = self.neighbor_method
10791052

10801053
# find the random neighbor based on the method
1081-
if neighbor_method == "Hamming":
1082-
return self.__get_random_neighbor_hamming(param_config)
1083-
elif neighbor_method == "adjacent":
1054+
if neighbor_method == "adjacent":
10841055
return self.__get_random_neighbor_adjacent(param_config)
1056+
# elif neighbor_method == "Hamming":
1057+
# this implementation is not as efficient as just generating all neighbors
1058+
# return self.__get_random_neighbor_hamming(param_config)
10851059
else:
10861060
# not much performance to be gained for Hamming and strictly-adjacent neighbors, just generate the neighbors
10871061
neighbors = self.get_neighbors(param_config, neighbor_method)

kernel_tuner/strategies/simulated_annealing.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,14 @@ def neighbor(pos, searchspace: Searchspace, constraint_aware=True):
116116

117117
def random_neighbor(pos, method):
118118
"""Helper method to return a random neighbor."""
119-
neighbors = searchspace.get_neighbors(pos, neighbor_method=method)
120-
if not neighbors:
119+
# neighbors = searchspace.get_neighbors(pos, neighbor_method=method)
120+
# if not neighbors:
121+
# return pos
122+
# return random.choice(neighbors)
123+
neighbor = searchspace.get_random_neighbor(pos, neighbor_method=method)
124+
if neighbor is None:
121125
return pos
122-
return random.choice(neighbors)
126+
return neighbor
123127

124128
size = len(pos)
125129

0 commit comments

Comments
 (0)