Skip to content

Commit 36208d1

Browse files
committed
Implemented pyATF search space lookup of configs
1 parent 49e786d commit 36208d1

File tree

1 file changed

+17
-29
lines changed

1 file changed

+17
-29
lines changed

kernel_tuner/strategies/pyatf_strategies.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,16 @@
77
from kernel_tuner.strategies.common import CostFunc
88
from kernel_tuner.util import StopCriterionReached
99

10-
supported_searchtechniques = ["auc_bandit", "differential_evolution", "pattern_search", "round_robin", "simulated_annealing"]
10+
supported_searchtechniques = ["auc_bandit", "differential_evolution", "pattern_search", "round_robin", "simulated_annealing", "torczon"]
1111

1212
_options = dict(searchtechnique=(f"PyATF optimization algorithm to use, choose any from {supported_searchtechniques}", "simulated_annealing"))
1313

1414
def tune(searchspace: Searchspace, runner, tuning_options):
1515
from pyatf.search_techniques.search_technique import SearchTechnique
16+
from pyatf.search_space import SearchSpace as pyATFSearchSpace
1617

1718
# setup the Kernel Tuner functionalities
18-
cost_func = CostFunc(searchspace, tuning_options, runner, scaling=True, snap=True, return_invalid=True)
19-
# using this instead of get_bounds because scaling is used
20-
bounds, _, eps = cost_func.get_bounds_x0_eps()
19+
cost_func = CostFunc(searchspace, tuning_options, runner, scaling=False, snap=False, return_invalid=False)
2120

2221
# dynamically import the search technique based on the provided options
2322
module_name, = common.get_options(tuning_options.strategy_options, _options)
@@ -31,9 +30,17 @@ def tune(searchspace: Searchspace, runner, tuning_options):
3130
assert isinstance(search_technique, SearchTechnique), f"Search technique {search_technique} is not a valid pyATF search technique."
3231

3332
# initialize the search space
34-
# from pyatf.search_space import SearchSpace as PyATFSearchSpace
35-
# assert searchspace.tune_params_pyatf is not None
36-
# search_space = PyATFSearchSpace(*searchspace.tune_params_pyatf, enable_1d_access=False) # SearchTechnique1D currently not supported
33+
searchspace_pyatf = Searchspace(
34+
searchspace.tune_params,
35+
tuning_options.restrictions_unmodified,
36+
searchspace.max_threads,
37+
searchspace.block_size_names,
38+
defer_construction=True,
39+
framework="pyatf"
40+
)
41+
tune_params_pyatf = searchspace_pyatf.get_tune_params_pyatf()
42+
assert isinstance(tune_params_pyatf, (tuple, list)), f"Tuning parameters must be a tuple or list of tuples, is {type(tune_params_pyatf)} ({tune_params_pyatf})."
43+
search_space_pyatf = pyATFSearchSpace(*tune_params_pyatf, enable_1d_access=False) # SearchTechnique1D currently not supported
3744

3845
# initialize
3946
get_next_coordinates_or_indices = search_technique.get_next_coordinates
@@ -54,16 +61,13 @@ def tune(searchspace: Searchspace, runner, tuning_options):
5461

5562
# get configuration
5663
coords_or_index = coordinates_or_indices.pop()
57-
# config = search_space.get_configuration(coords_or_index)
64+
config = search_space_pyatf.get_configuration(coords_or_index)
5865
valid = True
5966
cost = None
6067

61-
# convert normalized coordinates of each parameter to range of bounds (from [0, 1] to [bound[0], bound[1]])
62-
if isinstance(coords_or_index, tuple):
63-
coords_or_index = tuple(b[0]+c*(b[1]-b[0]) for c, b in zip(coords_or_index, bounds) if c is not None)
64-
6568
# evaluate the configuration
66-
opt_result = cost_func(coords_or_index)
69+
x = tuple([config[k] for k in searchspace.tune_params.keys()])
70+
opt_result = cost_func(x, check_restrictions=False)
6771

6872
# adjust opt_result to expected PyATF output in cost and valid
6973
if not isinstance(opt_result, (int, float)):
@@ -81,22 +85,6 @@ def tune(searchspace: Searchspace, runner, tuning_options):
8185

8286
return cost_func.results
8387

84-
# scale variables in x to make 'eps' relevant for multiple variables
85-
cost_func = CostFunc(searchspace, tuning_options, runner, scaling=True)
86-
87-
opt_result = None
88-
try:
89-
opt_result = searchtechnique(cost_func)
90-
except StopCriterionReached as e:
91-
searchtechnique.finalize()
92-
if tuning_options.verbose:
93-
print(e)
94-
95-
if opt_result and tuning_options.verbose:
96-
print(opt_result.message)
97-
98-
return cost_func.results
99-
10088

10189
# class TuningRun:
10290
# def __init__(self,

0 commit comments

Comments
 (0)