Skip to content

Commit 9821156

Browse files
committed
Merge remote-tracking branch 'origin/searchspace_experiments' into constrained_optimization_tunable
2 parents 6cd8029 + 3c6b96e commit 9821156

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

kernel_tuner/searchspace.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(
9999
self.params_values_indices = None
100100
self.build_neighbors_index = build_neighbors_index
101101
self.solver_method = solver_method
102-
self.__neighbor_cache = dict()
102+
self.__neighbor_cache = { method: dict() for method in supported_neighbor_methods }
103103
self.neighbor_method = neighbor_method
104104
if (neighbor_method is not None or build_neighbors_index) and neighbor_method not in supported_neighbor_methods:
105105
raise ValueError(f"Neighbor method is {neighbor_method}, must be one of {supported_neighbor_methods}")
@@ -934,24 +934,25 @@ def get_neighbors_indices_no_cache(self, param_config: tuple, neighbor_method=No
934934
raise ValueError(f"The neighbor method {neighbor_method} is not in {supported_neighbor_methods}")
935935

936936
def get_neighbors_indices(self, param_config: tuple, neighbor_method=None) -> List[int]:
937-
"""Get the neighbors indices for a parameter configuration, possibly cached."""
938-
neighbors = self.__neighbor_cache.get(param_config, None)
937+
"""Get the neighbors indices for a parameter configuration, cached if requested before."""
938+
if neighbor_method is None:
939+
neighbor_method = self.neighbor_method
940+
if neighbor_method is None:
941+
raise ValueError("Neither the neighbor_method argument nor self.neighbor_method was set")
942+
neighbors = self.__neighbor_cache[neighbor_method].get(param_config, None)
939943
# if there are no cached neighbors, compute them
940944
if neighbors is None:
941945
neighbors = self.get_neighbors_indices_no_cache(param_config, neighbor_method)
942-
self.__neighbor_cache[param_config] = neighbors
943-
# if the neighbors were cached but the specified neighbor method was different than the one initially used to build the cache, throw an error
944-
elif (
945-
self.neighbor_method is not None and neighbor_method is not None and self.neighbor_method != neighbor_method
946-
):
947-
raise ValueError(
948-
f"The neighbor method {neighbor_method} differs from the intially set {self.neighbor_method}, can not use cached neighbors. Use 'get_neighbors_no_cache()' when mixing neighbor methods to avoid this."
949-
)
946+
self.__neighbor_cache[neighbor_method][param_config] = neighbors
950947
return neighbors
951948

952-
def are_neighbors_indices_cached(self, param_config: tuple) -> bool:
949+
def are_neighbors_indices_cached(self, param_config: tuple, neighbor_method=None) -> bool:
953950
"""Returns true if the neighbor indices are in the cache, false otherwise."""
954-
return param_config in self.__neighbor_cache
951+
if neighbor_method is None:
952+
neighbor_method = self.neighbor_method
953+
if neighbor_method is None:
954+
raise ValueError("Neither the neighbor_method argument nor self.neighbor_method was set")
955+
return param_config in self.__neighbor_cache[neighbor_method]
955956

956957
def get_neighbors_no_cache(self, param_config: tuple, neighbor_method=None) -> List[tuple]:
957958
"""Get the neighbors for a parameter configuration (does not check running cache, useful when mixing neighbor methods)."""

test/test_searchspace.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,34 @@ def test_neighbors_cached():
320320
assert neighbors == neighbors_2
321321

322322

323+
def test_neighbors_cached_mixed_methods():
324+
"""Test whether retrieving a set of neighbors with one method after another yields the correct neighbors."""
325+
simple_searchspace_duplicate = Searchspace(
326+
simple_tuning_options.tune_params,
327+
simple_tuning_options.restrictions,
328+
max_threads,
329+
)
330+
331+
test_configs = simple_searchspace_duplicate.get_random_sample(5)
332+
for test_config in test_configs:
333+
assert not simple_searchspace_duplicate.are_neighbors_indices_cached(test_config, "Hamming")
334+
neighbors_hamming = simple_searchspace_duplicate.get_neighbors(test_config, "Hamming")
335+
assert simple_searchspace_duplicate.are_neighbors_indices_cached(test_config, "Hamming")
336+
337+
# now switch to a different method
338+
neighbors_strictlyadjacent = simple_searchspace_duplicate.get_neighbors(test_config, "strictly-adjacent")
339+
neighbors_strictlyadjacent_no_cache = simple_searchspace_duplicate.get_neighbors_no_cache(test_config, "strictly-adjacent")
340+
341+
neighbors_adjacent = simple_searchspace_duplicate.get_neighbors(test_config, "adjacent")
342+
neighbors_adjacent_no_cache = simple_searchspace_duplicate.get_neighbors_no_cache(test_config, "adjacent")
343+
344+
# check that the neighbors are as expected
345+
assert neighbors_strictlyadjacent == neighbors_strictlyadjacent_no_cache
346+
assert neighbors_adjacent == neighbors_adjacent_no_cache
347+
assert neighbors_hamming != neighbors_strictlyadjacent
348+
assert neighbors_hamming != neighbors_adjacent
349+
350+
323351
def test_param_neighbors():
324352
"""Test whether for a given parameter configuration and index the correct neighboring parameters are returned."""
325353
test_config = tuple([1.5, 4, "string_1"])

0 commit comments

Comments
 (0)