24
24
except ImportError :
25
25
bayes_opt_present = False
26
26
27
- from kernel_tuner import util
28
-
29
27
supported_methods = ["poi" , "ei" , "lcb" , "lcb-srinivas" , "multi" , "multi-advanced" , "multi-fast" , "multi-ultrafast" ]
30
28
31
29
@@ -107,19 +105,8 @@ def tune(searchspace: Searchspace, runner, tuning_options):
107
105
_ , _ , eps = cost_func .get_bounds_x0_eps ()
108
106
109
107
# compute cartesian product of all tunable parameters
110
- parameter_space = itertools .product (* tune_params .values ())
111
-
112
- # check for search space restrictions
113
- if searchspace .restrictions is not None :
114
- tuning_options .verbose = False
115
- parameter_space = filter (lambda p : util .config_valid (p , tuning_options , runner .dev .max_threads ), parameter_space )
116
- parameter_space = list (parameter_space )
117
- if len (parameter_space ) < 1 :
118
- raise ValueError ("Empty parameterspace after restrictionscheck. Restrictionscheck is possibly too strict." )
119
- if len (parameter_space ) == 1 :
120
- raise ValueError (
121
- f"Only one configuration after restrictionscheck. Restrictionscheck is possibly too strict. Configuration: { parameter_space [0 ]} "
122
- )
108
+ # TODO actually use the Searchspace object properly throughout Bayesian Optimization
109
+ parameter_space = searchspace .list
123
110
124
111
# normalize search space to [0,1]
125
112
normalize_dict , denormalize_dict = generate_normalized_param_dicts (tune_params , eps )
@@ -137,7 +124,7 @@ def tune(searchspace: Searchspace, runner, tuning_options):
137
124
# initialize and optimize
138
125
try :
139
126
bo = BayesianOptimization (
140
- parameter_space , removed_tune_params , tuning_options , normalize_dict , denormalize_dict , cost_func
127
+ parameter_space , searchspace , removed_tune_params , tuning_options , normalize_dict , denormalize_dict , cost_func
141
128
)
142
129
except StopCriterionReached :
143
130
warnings .warn (
@@ -179,6 +166,7 @@ class BayesianOptimization:
179
166
def __init__ (
180
167
self ,
181
168
searchspace : list ,
169
+ searchspace_obj : Searchspace ,
182
170
removed_tune_params : list ,
183
171
tuning_options : dict ,
184
172
normalize_dict : dict ,
@@ -256,6 +244,7 @@ def get_hyperparam(name: str, default, supported_values=list()):
256
244
257
245
# set remaining values
258
246
self .__searchspace = searchspace
247
+ self .__searchspace_obj = searchspace_obj
259
248
self .removed_tune_params = removed_tune_params
260
249
self .searchspace_size = len (self .searchspace )
261
250
self .num_dimensions = len (self .dimensions ())
@@ -463,7 +452,7 @@ def evaluate_objective_function(self, param_config: tuple) -> float:
463
452
"""Evaluates the objective function."""
464
453
param_config = self .unprune_param_config (param_config )
465
454
denormalized_param_config = self .denormalize_param_config (param_config )
466
- if not util . config_valid ( denormalized_param_config , self .tuning_options , self . max_threads ):
455
+ if not self .__searchspace_obj . is_param_config_valid ( denormalized_param_config ):
467
456
return self .invalid_value
468
457
val = self .cost_func (param_config )
469
458
self .fevals += 1
0 commit comments