55from kernel_tuner .strategies .common import (
66 CostFunc ,
77 get_options ,
8- scale_from_params ,
8+ snap_to_nearest_config ,
99 get_strategy_docstring ,
1010)
1111
1212supported_methods = ["forest" , "gbrt" , "gp" , "dummy" ]
1313
1414_options = dict (
15- method = (f"Local optimization algorithm to use, choose any from { supported_methods } " , "gp" ),
16- options = ("Options passed to the skopt method as kwargs." , dict ()),
17- popsize = ("Number of initial samples. If `None`, let skopt choose the initial population" , None ),
18- maxiter = ("Maximum number of times to repeat the method until the budget is exhausted." , 1 ),
15+ method = (f"Local optimization algorithm to use, choose any from { supported_methods } " , "gp" ),
16+ options = ("Options passed to the skopt method as kwargs." , dict ()),
17+ popsize = ("Number of initial samples. If `None`, let skopt choose the initial population" , None ),
18+ maxiter = ("Maximum number of times to repeat the method until the budget is exhausted." , 1 ),
1919)
2020
21+
2122def tune (searchspace : Searchspace , runner , tuning_options ):
2223 import skopt
2324
24- cost_func = CostFunc (searchspace , tuning_options , runner , scaling = True , invalid_value = 1e9 )
25- bounds , _ , eps = cost_func .get_bounds_x0_eps ()
26-
2725 method , skopt_options , popsize , maxiter = get_options (tuning_options .strategy_options , _options )
2826
2927 # Get maximum number of evaluations
@@ -32,8 +30,8 @@ def tune(searchspace: Searchspace, runner, tuning_options):
3230 max_fevals = min (tuning_options ["max_fevals" ], max_fevals )
3331
3432 # Set the maximum number of calls to 100 times the maximum number of evaluations.
35- # Not all calls by skopt will result in an evaluation, due to restrictions or
36- # since different calls might map to the same configuration.
33+ # Not all calls by skopt will result in an evaluation since different calls might
34+ # map to the same configuration.
3735 if "n_calls" not in skopt_options :
3836 skopt_options ["n_calls" ] = 100 * max_fevals
3937
@@ -42,29 +40,37 @@ def tune(searchspace: Searchspace, runner, tuning_options):
4240 # the samples as it is not aware of restrictions.
4341 if popsize :
4442 x0 = searchspace .get_random_sample (min (popsize , max_fevals ))
45- skopt_options ["x0" ] = [list (scale_from_params (x , searchspace .tune_params , eps )) for x in x0 ]
46-
43+ skopt_options ["x0" ] = [searchspace .get_param_indices (x ) for x in x0 ]
4744
4845 opt_result = None
46+ tune_params_values = list (searchspace .tune_params .values ())
47+ bounds = [(0 , len (p ) - 1 ) if len (p ) > 1 else [0 ] for p in tune_params_values ]
48+
49+ cost_func = CostFunc (searchspace , tuning_options , runner )
50+ objective = lambda x : cost_func (searchspace .get_param_config_from_param_indices (x ))
51+ space_constraint = lambda x : searchspace .is_param_config_valid (searchspace .get_param_config_from_param_indices (x ))
52+
53+ skopt_options ["space_constraint" ] = space_constraint
54+ skopt_options ["verbose" ] = tuning_options .verbose
4955
5056 try :
5157 for _ in range (maxiter ):
5258 if method == "dummy" :
53- opt_result = skopt .dummy_minimize (cost_func , bounds , ** skopt_options )
59+ opt_result = skopt .dummy_minimize (objective , bounds , ** skopt_options )
5460 elif method == "forest" :
55- opt_result = skopt .forest_minimize (cost_func , bounds , ** skopt_options )
61+ opt_result = skopt .forest_minimize (objective , bounds , ** skopt_options )
5662 elif method == "gp" :
57- opt_result = skopt .gp_minimize (cost_func , bounds , ** skopt_options )
63+ opt_result = skopt .gp_minimize (objective , bounds , ** skopt_options )
5864 elif method == "gbrt" :
59- opt_result = skopt .gbrt_minimize (cost_func , bounds , ** skopt_options )
65+ opt_result = skopt .gbrt_minimize (objective , bounds , ** skopt_options )
6066 else :
6167 raise ValueError (f"invalid skopt method: { method } " )
6268 except StopCriterionReached as e :
6369 if tuning_options .verbose :
6470 print (e )
6571
6672 if opt_result and tuning_options .verbose :
67- print (opt_result . message )
73+ print (opt_result )
6874
6975 return cost_func .results
7076
0 commit comments