Skip to content

Commit ad88894

Browse files
committed
Extra check to make sure non-x0 supporting strategies raise an exception, implemented unsupported options check in Bayesian Optimization
1 parent f0afa0b commit ad88894

File tree

2 files changed

+30
-25
lines changed

2 files changed

+30
-25
lines changed

kernel_tuner/strategies/bayes_opt.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
# BO imports
1515
from kernel_tuner.searchspace import Searchspace
16-
from kernel_tuner.strategies.common import CostFunc
16+
from kernel_tuner.strategies.common import CostFunc, get_options
1717
from kernel_tuner.util import StopCriterionReached
1818

1919
try:
@@ -26,6 +26,24 @@
2626

2727
supported_methods = ["poi", "ei", "lcb", "lcb-srinivas", "multi", "multi-advanced", "multi-fast", "multi-ultrafast"]
2828

29+
# _options dict is used for generating documentation, but is not used to check for unsupported strategy_options in bayes_opt
30+
_options = dict(
31+
covariancekernel=(
32+
'The Covariance kernel to use, choose any from "constantrbf", "rbf", "matern32", "matern52"',
33+
"matern32",
34+
),
35+
covariancelengthscale=("The covariance length scale", 1.5),
36+
method=(
37+
"The Bayesian Optimization method to use, choose any from " + ", ".join(supported_methods),
38+
"multi-ultrafast",
39+
),
40+
samplingmethod=(
41+
"Method used for initial sampling the parameter space, either random or Latin Hypercube Sampling (LHS)",
42+
"lhs",
43+
),
44+
popsize=("Number of initial samples", 20),
45+
)
46+
2947

3048
def generate_normalized_param_dicts(tune_params: dict, eps: float) -> Tuple[dict, dict]:
3149
"""Generates normalization and denormalization dictionaries."""
@@ -92,6 +110,9 @@ def tune(searchspace: Searchspace, runner, tuning_options):
92110
:rtype: list(dict()), dict()
93111
94112
"""
113+
# we don't actually use this for Bayesian Optimization, but it is used to check for unsupported options
114+
get_options(tuning_options.strategy_options, _options, unsupported=["x0"])
115+
95116
max_fevals = tuning_options.strategy_options.get("max_fevals", 100)
96117
prune_parameterspace = tuning_options.strategy_options.get("pruneparameterspace", True)
97118
if not bayes_opt_present:
@@ -143,25 +164,6 @@ def tune(searchspace: Searchspace, runner, tuning_options):
143164
return cost_func.results
144165

145166

146-
# _options dict is used for generating documentation, but is not used to check for unsupported strategy_options in bayes_opt
147-
_options = dict(
148-
covariancekernel=(
149-
'The Covariance kernel to use, choose any from "constantrbf", "rbf", "matern32", "matern52"',
150-
"matern32",
151-
),
152-
covariancelengthscale=("The covariance length scale", 1.5),
153-
method=(
154-
"The Bayesian Optimization method to use, choose any from " + ", ".join(supported_methods),
155-
"multi-ultrafast",
156-
),
157-
samplingmethod=(
158-
"Method used for initial sampling the parameter space, either random or Latin Hypercube Sampling (LHS)",
159-
"lhs",
160-
),
161-
popsize=("Number of initial samples", 20),
162-
)
163-
164-
165167
class BayesianOptimization:
166168
def __init__(
167169
self,

test/strategies/test_strategies.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,17 @@ def test_strategies(vector_add, strategy):
8282
assert isinstance(res[expected_key], expected_type)
8383

8484
# check if strategy respects user-specified starting point (x0)
85+
x0 = [256]
86+
filter_options["x0"] = x0
8587
if not strategy in ["brute_force", "random_sample", "bayes_opt"]:
86-
x0 = [256]
87-
filter_options["x0"] = x0
88-
8988
results, _ = kernel_tuner.tune_kernel(*vector_add, strategy=strategy, strategy_options=filter_options,
90-
verbose=False, cache=cache_filename, simulation_mode=True)
91-
89+
verbose=False, cache=cache_filename, simulation_mode=True)
9290
assert results[0]["block_size_x"] == x0[0]
91+
else:
92+
with pytest.raises(ValueError):
93+
results, _ = kernel_tuner.tune_kernel(*vector_add, strategy=strategy, strategy_options=filter_options,
94+
verbose=False, cache=cache_filename, simulation_mode=True)
95+
9396

9497

9598

0 commit comments

Comments
 (0)