Skip to content

Commit 5c94c78

Browse files
committed
Basic implementation to use Searchspace in Bayesian Optimization
1 parent ecf7218 commit 5c94c78

File tree

2 files changed

+7
-18
lines changed

2 files changed

+7
-18
lines changed

kernel_tuner/strategies/bayes_opt.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
except ImportError:
2525
bayes_opt_present = False
2626

27-
from kernel_tuner import util
28-
2927
supported_methods = ["poi", "ei", "lcb", "lcb-srinivas", "multi", "multi-advanced", "multi-fast", "multi-ultrafast"]
3028

3129

@@ -107,19 +105,8 @@ def tune(searchspace: Searchspace, runner, tuning_options):
107105
_, _, eps = cost_func.get_bounds_x0_eps()
108106

109107
# compute cartesian product of all tunable parameters
110-
parameter_space = itertools.product(*tune_params.values())
111-
112-
# check for search space restrictions
113-
if searchspace.restrictions is not None:
114-
tuning_options.verbose = False
115-
parameter_space = filter(lambda p: util.config_valid(p, tuning_options, runner.dev.max_threads), parameter_space)
116-
parameter_space = list(parameter_space)
117-
if len(parameter_space) < 1:
118-
raise ValueError("Empty parameterspace after restrictionscheck. Restrictionscheck is possibly too strict.")
119-
if len(parameter_space) == 1:
120-
raise ValueError(
121-
f"Only one configuration after restrictionscheck. Restrictionscheck is possibly too strict. Configuration: {parameter_space[0]}"
122-
)
108+
# TODO actually use the Searchspace object properly throughout Bayesian Optimization
109+
parameter_space = searchspace.list
123110

124111
# normalize search space to [0,1]
125112
normalize_dict, denormalize_dict = generate_normalized_param_dicts(tune_params, eps)
@@ -137,7 +124,7 @@ def tune(searchspace: Searchspace, runner, tuning_options):
137124
# initialize and optimize
138125
try:
139126
bo = BayesianOptimization(
140-
parameter_space, removed_tune_params, tuning_options, normalize_dict, denormalize_dict, cost_func
127+
parameter_space, searchspace, removed_tune_params, tuning_options, normalize_dict, denormalize_dict, cost_func
141128
)
142129
except StopCriterionReached:
143130
warnings.warn(
@@ -179,6 +166,7 @@ class BayesianOptimization:
179166
def __init__(
180167
self,
181168
searchspace: list,
169+
searchspace_obj: Searchspace,
182170
removed_tune_params: list,
183171
tuning_options: dict,
184172
normalize_dict: dict,
@@ -256,6 +244,7 @@ def get_hyperparam(name: str, default, supported_values=list()):
256244

257245
# set remaining values
258246
self.__searchspace = searchspace
247+
self.__searchspace_obj = searchspace_obj
259248
self.removed_tune_params = removed_tune_params
260249
self.searchspace_size = len(self.searchspace)
261250
self.num_dimensions = len(self.dimensions())
@@ -463,7 +452,7 @@ def evaluate_objective_function(self, param_config: tuple) -> float:
463452
"""Evaluates the objective function."""
464453
param_config = self.unprune_param_config(param_config)
465454
denormalized_param_config = self.denormalize_param_config(param_config)
466-
if not util.config_valid(denormalized_param_config, self.tuning_options, self.max_threads):
455+
if not self.__searchspace_obj.is_param_config_valid(denormalized_param_config):
467456
return self.invalid_value
468457
val = self.cost_func(param_config)
469458
self.fevals += 1

test/strategies/test_bayesian_optimization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
pruned_parameter_space, removed_tune_params = bayes_opt.prune_parameter_space(normalized_parameter_space, tuning_options, tune_params, original_to_normalized)
3838

3939
# initialize BO
40-
BO = BayesianOptimization(pruned_parameter_space, removed_tune_params, tuning_options, original_to_normalized, normalized_to_original, cost_func)
40+
BO = BayesianOptimization(pruned_parameter_space, searchspace, removed_tune_params, tuning_options, original_to_normalized, normalized_to_original, cost_func)
4141
predictions, _, std = BO.predict_list(BO.unvisited_cache)
4242

4343

0 commit comments

Comments
 (0)