Skip to content

Commit 1a2bd44

Browse files
committed
Add new Hamming-adjacent neighborhood method to Searchspace
1 parent b4cd21d commit 1a2bd44

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

kernel_tuner/searchspace.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from kernel_tuner.util import check_restrictions as check_instance_restrictions
2121
from kernel_tuner.util import compile_restrictions, default_block_size_names
2222

23-
supported_neighbor_methods = ["strictly-adjacent", "adjacent", "Hamming"]
23+
supported_neighbor_methods = ["strictly-adjacent", "adjacent", "Hamming", "Hamming-adjacent"]
2424

2525

2626
class Searchspace:
@@ -44,6 +44,7 @@ def __init__(
4444
strictly-adjacent: differs +1 or -1 parameter index value for each parameter
4545
adjacent: picks closest parameter value in both directions for each parameter
4646
Hamming: any parameter config with 1 different parameter value is a neighbor
47+
Hamming-adjacent: differs +1 or -1 parameter index value for exactly 1 parameter.
4748
Optionally sort the searchspace by the order in which the parameter values were specified. By default, sort goes from first to last parameter, to reverse this use sort_last_param_first.
4849
"""
4950
# set the object attributes using the arguments
@@ -552,6 +553,45 @@ def __get_neighbors_indices_hamming(self, param_config: tuple) -> List[int]:
552553
matching_indices = (num_matching_params == self.num_params - 1).nonzero()[0]
553554
return matching_indices
554555

556+
def __get_neighbors_indices_hammingadjacent(self, param_config_index: int = None, param_config: tuple = None) -> List[int]:
557+
"""Get the neighbors using adjacent distance from the parameter configuration (parameter index absolute difference >= 1)."""
558+
param_config_value_indices = (
559+
self.get_param_indices(param_config)
560+
if param_config_index is None
561+
else self.params_values_indices[param_config_index]
562+
)
563+
564+
# compute boolean mask for all configuration that differ at exactly one parameter (Hamming distance == 1)
565+
hamming_mask = np.count_nonzero(self.params_values_indices != param_config_value_indices, axis=1) == 1
566+
567+
# get the configuration indices of the hamming neighbors
568+
hamming_indices, = np.nonzero(hamming_mask)
569+
570+
# for the hamming neighbors, calculate the difference between parameter value indices
571+
hamming_index_difference = self.params_values_indices[hamming_mask] - param_config_value_indices
572+
573+
# for each parameter get the closest upper and lower parameter (absolute index difference >= 1)
574+
# np.PINF has been replaced by 1e12 here, as on some systems np.PINF becomes np.NINF
575+
upper_bound = np.min(
576+
hamming_index_difference,
577+
initial=1e12,
578+
axis=0,
579+
where=hamming_index_difference > 0,
580+
)
581+
582+
lower_bound = np.max(
583+
hamming_index_difference,
584+
initial=-1e12,
585+
axis=0,
586+
where=hamming_index_difference < 0,
587+
)
588+
589+
# return mask for adjacent neighbors (each parameter is within bounds)
590+
adjacent_mask = np.all((lower_bound <= hamming_index_difference) & (hamming_index_difference <= upper_bound), axis=1)
591+
592+
# return hamming neighbors that are also adjacent
593+
return hamming_indices[adjacent_mask]
594+
555595
def __get_neighbors_indices_strictlyadjacent(
556596
self, param_config_index: int = None, param_config: tuple = None
557597
) -> List[int]:
@@ -615,6 +655,13 @@ def __build_neighbors_index(self, neighbor_method) -> List[List[int]]:
615655
# for each parameter configuration, find the neighboring parameter configurations
616656
if self.params_values_indices is None:
617657
self.__prepare_neighbors_index()
658+
659+
if neighbor_method == "Hamming-adjacent":
660+
return list(
661+
self.__get_neighbors_indices_hammingadjacent(param_config_index, param_config)
662+
for param_config_index, param_config in enumerate(self.list)
663+
)
664+
618665
if neighbor_method == "strictly-adjacent":
619666
return list(
620667
self.__get_neighbors_indices_strictlyadjacent(param_config_index, param_config)
@@ -667,6 +714,8 @@ def get_neighbors_indices_no_cache(self, param_config: tuple, neighbor_method=No
667714
self.__prepare_neighbors_index()
668715

669716
# if the passed param_config is fictious, we can not use the pre-calculated neighbors index
717+
if neighbor_method == "Hamming-adjacent":
718+
return self.__get_neighbors_indices_hammingadjacent(param_config_index, param_config)
670719
if neighbor_method == "strictly-adjacent":
671720
return self.__get_neighbors_indices_strictlyadjacent(param_config_index, param_config)
672721
if neighbor_method == "adjacent":

test/test_searchspace.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,16 @@ def test_neighbors_hamming():
245245
__test_neighbors(test_config, expected_neighbors, "Hamming")
246246

247247

248+
def test_neighbors_hammingadjacent():
249+
"""Test whether the Hamming-adjacent neighbors are as expected."""
250+
test_config = tuple([1, 4, "string_1"])
251+
expected_neighbors = [
252+
(1.5, 4, 'string_1'),
253+
]
254+
255+
__test_neighbors(test_config, expected_neighbors, "Hamming-adjacent")
256+
257+
248258
def test_neighbors_strictlyadjacent():
249259
"""Test whether the strictly adjacent neighbors are as expected."""
250260
test_config = tuple([1, 4, "string_1"])
@@ -274,11 +284,19 @@ def test_neighbors_adjacent():
274284
def test_neighbors_fictious():
275285
"""Test whether the neighbors are as expected for a fictious parameter configuration (i.e. not existing in the search space due to restrictions)."""
276286
test_config = tuple([1.5, 4, "string_1"])
287+
277288
expected_neighbors_hamming = [
278289
(1.5, 4, 'string_2'),
279290
(1.5, 5.5, 'string_1'),
280291
(3, 4, 'string_1'),
281292
]
293+
294+
expected_neighbors_hammingadjacent = [
295+
(1.5, 4, 'string_2'),
296+
(1.5, 5.5, 'string_1'),
297+
(3, 4, 'string_1'),
298+
]
299+
282300
expected_neighbors_strictlyadjacent = [
283301
(1.5, 5.5, 'string_2'),
284302
(1.5, 5.5, 'string_1'),
@@ -294,6 +312,7 @@ def test_neighbors_fictious():
294312
]
295313

296314
__test_neighbors_direct(test_config, expected_neighbors_hamming, "Hamming")
315+
__test_neighbors_direct(test_config, expected_neighbors_hammingadjacent, "Hamming-adjacent")
297316
__test_neighbors_direct(test_config, expected_neighbors_strictlyadjacent, "strictly-adjacent")
298317
__test_neighbors_direct(test_config, expected_neighbors_adjacent, "adjacent")
299318

0 commit comments

Comments
 (0)