Skip to content

Commit 632df23

Browse files
committed
Add space_constraint option for skopt strategy
1 parent a0a6c8d commit 632df23

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

kernel_tuner/strategies/skopt.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,23 @@
55
from kernel_tuner.strategies.common import (
66
CostFunc,
77
get_options,
8-
scale_from_params,
8+
snap_to_nearest_config,
99
get_strategy_docstring,
1010
)
1111

1212
supported_methods = ["forest", "gbrt", "gp", "dummy"]
1313

1414
_options = dict(
15-
method=(f"Local optimization algorithm to use, choose any from {supported_methods}", "gp"),
16-
options=("Options passed to the skopt method as kwargs.", dict()),
17-
popsize=("Number of initial samples. If `None`, let skopt choose the initial population", None),
18-
maxiter=("Maximum number of times to repeat the method until the budget is exhausted.", 1),
15+
method=(f"Local optimization algorithm to use, choose any from {supported_methods}", "gp"),
16+
options=("Options passed to the skopt method as kwargs.", dict()),
17+
popsize=("Number of initial samples. If `None`, let skopt choose the initial population", None),
18+
maxiter=("Maximum number of times to repeat the method until the budget is exhausted.", 1),
1919
)
2020

21+
2122
def tune(searchspace: Searchspace, runner, tuning_options):
2223
import skopt
2324

24-
cost_func = CostFunc(searchspace, tuning_options, runner, scaling=True, invalid_value=1e9)
25-
bounds, _, eps = cost_func.get_bounds_x0_eps()
26-
2725
method, skopt_options, popsize, maxiter = get_options(tuning_options.strategy_options, _options)
2826

2927
# Get maximum number of evaluations
@@ -32,8 +30,8 @@ def tune(searchspace: Searchspace, runner, tuning_options):
3230
max_fevals = min(tuning_options["max_fevals"], max_fevals)
3331

3432
# Set the maximum number of calls to 100 times the maximum number of evaluations.
35-
# Not all calls by skopt will result in an evaluation, due to restrictions or
36-
# since different calls might map to the same configuration.
33+
# Not all calls by skopt will result in an evaluation since different calls might
34+
# map to the same configuration.
3735
if "n_calls" not in skopt_options:
3836
skopt_options["n_calls"] = 100 * max_fevals
3937

@@ -42,29 +40,37 @@ def tune(searchspace: Searchspace, runner, tuning_options):
4240
# the samples as it is not aware of restrictions.
4341
if popsize:
4442
x0 = searchspace.get_random_sample(min(popsize, max_fevals))
45-
skopt_options["x0"] = [list(scale_from_params(x, searchspace.tune_params, eps)) for x in x0]
46-
43+
skopt_options["x0"] = [searchspace.get_param_indices(x) for x in x0]
4744

4845
opt_result = None
46+
tune_params_values = list(searchspace.tune_params.values())
47+
bounds = [(0, len(p) - 1) if len(p) > 1 else [0] for p in tune_params_values]
48+
49+
cost_func = CostFunc(searchspace, tuning_options, runner)
50+
objective = lambda x: cost_func(searchspace.get_param_config_from_param_indices(x))
51+
space_constraint = lambda x: searchspace.is_param_config_valid(searchspace.get_param_config_from_param_indices(x))
52+
53+
skopt_options["space_constraint"] = space_constraint
54+
skopt_options["verbose"] = tuning_options.verbose
4955

5056
try:
5157
for _ in range(maxiter):
5258
if method == "dummy":
53-
opt_result = skopt.dummy_minimize(cost_func, bounds, **skopt_options)
59+
opt_result = skopt.dummy_minimize(objective, bounds, **skopt_options)
5460
elif method == "forest":
55-
opt_result = skopt.forest_minimize(cost_func, bounds, **skopt_options)
61+
opt_result = skopt.forest_minimize(objective, bounds, **skopt_options)
5662
elif method == "gp":
57-
opt_result = skopt.gp_minimize(cost_func, bounds, **skopt_options)
63+
opt_result = skopt.gp_minimize(objective, bounds, **skopt_options)
5864
elif method == "gbrt":
59-
opt_result = skopt.gbrt_minimize(cost_func, bounds, **skopt_options)
65+
opt_result = skopt.gbrt_minimize(objective, bounds, **skopt_options)
6066
else:
6167
raise ValueError(f"invalid skopt method: {method}")
6268
except StopCriterionReached as e:
6369
if tuning_options.verbose:
6470
print(e)
6571

6672
if opt_result and tuning_options.verbose:
67-
print(opt_result.message)
73+
print(opt_result)
6874

6975
return cost_func.results
7076

0 commit comments

Comments
 (0)