Skip to content

Commit 3c6b96e

Browse files
committed
The Searchspace neighbor cache now caches for all neighbor methods separately
1 parent 6c804ff commit 3c6b96e

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

kernel_tuner/searchspace.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
self.params_values_indices = None
6565
self.build_neighbors_index = build_neighbors_index
6666
self.solver_method = solver_method
67-
self.__neighbor_cache = dict()
67+
self.__neighbor_cache = { method: dict() for method in supported_neighbor_methods }
6868
self.neighbor_method = neighbor_method
6969
if (neighbor_method is not None or build_neighbors_index) and neighbor_method not in supported_neighbor_methods:
7070
raise ValueError(f"Neighbor method is {neighbor_method}, must be one of {supported_neighbor_methods}")
@@ -758,24 +758,25 @@ def get_neighbors_indices_no_cache(self, param_config: tuple, neighbor_method=No
758758
raise ValueError(f"The neighbor method {neighbor_method} is not in {supported_neighbor_methods}")
759759

760760
def get_neighbors_indices(self, param_config: tuple, neighbor_method=None) -> List[int]:
761-
"""Get the neighbors indices for a parameter configuration, possibly cached."""
762-
neighbors = self.__neighbor_cache.get(param_config, None)
761+
"""Get the neighbors indices for a parameter configuration, cached if requested before."""
762+
if neighbor_method is None:
763+
neighbor_method = self.neighbor_method
764+
if neighbor_method is None:
765+
raise ValueError("Neither the neighbor_method argument nor self.neighbor_method was set")
766+
neighbors = self.__neighbor_cache[neighbor_method].get(param_config, None)
763767
# if there are no cached neighbors, compute them
764768
if neighbors is None:
765769
neighbors = self.get_neighbors_indices_no_cache(param_config, neighbor_method)
766-
self.__neighbor_cache[param_config] = neighbors
767-
# if the neighbors were cached but the specified neighbor method was different than the one initially used to build the cache, throw an error
768-
elif (
769-
self.neighbor_method is not None and neighbor_method is not None and self.neighbor_method != neighbor_method
770-
):
771-
raise ValueError(
772-
f"The neighbor method {neighbor_method} differs from the intially set {self.neighbor_method}, can not use cached neighbors. Use 'get_neighbors_no_cache()' when mixing neighbor methods to avoid this."
773-
)
770+
self.__neighbor_cache[neighbor_method][param_config] = neighbors
774771
return neighbors
775772

776-
def are_neighbors_indices_cached(self, param_config: tuple) -> bool:
773+
def are_neighbors_indices_cached(self, param_config: tuple, neighbor_method=None) -> bool:
777774
"""Returns true if the neighbor indices are in the cache, false otherwise."""
778-
return param_config in self.__neighbor_cache
775+
if neighbor_method is None:
776+
neighbor_method = self.neighbor_method
777+
if neighbor_method is None:
778+
raise ValueError("Neither the neighbor_method argument nor self.neighbor_method was set")
779+
return param_config in self.__neighbor_cache[neighbor_method]
779780

780781
def get_neighbors_no_cache(self, param_config: tuple, neighbor_method=None) -> List[tuple]:
781782
"""Get the neighbors for a parameter configuration (does not check running cache, useful when mixing neighbor methods)."""

test/test_searchspace.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,34 @@ def test_neighbors_cached():
316316
assert neighbors == neighbors_2
317317

318318

319+
def test_neighbors_cached_mixed_methods():
320+
"""Test whether retrieving a set of neighbors with one method after another yields the correct neighbors."""
321+
simple_searchspace_duplicate = Searchspace(
322+
simple_tuning_options.tune_params,
323+
simple_tuning_options.restrictions,
324+
max_threads,
325+
)
326+
327+
test_configs = simple_searchspace_duplicate.get_random_sample(5)
328+
for test_config in test_configs:
329+
assert not simple_searchspace_duplicate.are_neighbors_indices_cached(test_config, "Hamming")
330+
neighbors_hamming = simple_searchspace_duplicate.get_neighbors(test_config, "Hamming")
331+
assert simple_searchspace_duplicate.are_neighbors_indices_cached(test_config, "Hamming")
332+
333+
# now switch to a different method
334+
neighbors_strictlyadjacent = simple_searchspace_duplicate.get_neighbors(test_config, "strictly-adjacent")
335+
neighbors_strictlyadjacent_no_cache = simple_searchspace_duplicate.get_neighbors_no_cache(test_config, "strictly-adjacent")
336+
337+
neighbors_adjacent = simple_searchspace_duplicate.get_neighbors(test_config, "adjacent")
338+
neighbors_adjacent_no_cache = simple_searchspace_duplicate.get_neighbors_no_cache(test_config, "adjacent")
339+
340+
# check that the neighbors are as expected
341+
assert neighbors_strictlyadjacent == neighbors_strictlyadjacent_no_cache
342+
assert neighbors_adjacent == neighbors_adjacent_no_cache
343+
assert neighbors_hamming != neighbors_strictlyadjacent
344+
assert neighbors_hamming != neighbors_adjacent
345+
346+
319347
def test_param_neighbors():
320348
"""Test whether for a given parameter configuration and index the correct neighboring parameters are returned."""
321349
test_config = tuple([1.5, 4, "string_1"])

0 commit comments

Comments
 (0)