from kernel_tuner.strategies.common import CostFunc
from kernel_tuner.util import StopCriterionReached

-supported_searchtechniques = ["auc_bandit", "differential_evolution", "pattern_search", "round_robin", "simulated_annealing"]
+supported_searchtechniques = ["auc_bandit", "differential_evolution", "pattern_search", "round_robin", "simulated_annealing", "torczon"]

_options = dict(searchtechnique=(f"PyATF optimization algorithm to use, choose any from {supported_searchtechniques}", "simulated_annealing"))

def tune(searchspace: Searchspace, runner, tuning_options):
    from pyatf.search_techniques.search_technique import SearchTechnique
+    from pyatf.search_space import SearchSpace as pyATFSearchSpace

    # setup the Kernel Tuner functionalities
-    cost_func = CostFunc(searchspace, tuning_options, runner, scaling=True, snap=True, return_invalid=True)
-    # using this instead of get_bounds because scaling is used
-    bounds, _, eps = cost_func.get_bounds_x0_eps()
+    cost_func = CostFunc(searchspace, tuning_options, runner, scaling=False, snap=False, return_invalid=False)

    # dynamically import the search technique based on the provided options
    module_name, = common.get_options(tuning_options.strategy_options, _options)
@@ -31,9 +30,17 @@ def tune(searchspace: Searchspace, runner, tuning_options):
    assert isinstance(search_technique, SearchTechnique), f"Search technique {search_technique} is not a valid pyATF search technique."

    # initialize the search space
-    # from pyatf.search_space import SearchSpace as PyATFSearchSpace
-    # assert searchspace.tune_params_pyatf is not None
-    # search_space = PyATFSearchSpace(*searchspace.tune_params_pyatf, enable_1d_access=False)  # SearchTechnique1D currently not supported
+    searchspace_pyatf = Searchspace(
+        searchspace.tune_params,
+        tuning_options.restrictions_unmodified,
+        searchspace.max_threads,
+        searchspace.block_size_names,
+        defer_construction=True,
+        framework="pyatf"
+    )
+    tune_params_pyatf = searchspace_pyatf.get_tune_params_pyatf()
+    assert isinstance(tune_params_pyatf, (tuple, list)), f"Tuning parameters must be a tuple or list of tuples, is {type(tune_params_pyatf)} ({tune_params_pyatf})."
+    search_space_pyatf = pyATFSearchSpace(*tune_params_pyatf, enable_1d_access=False)  # SearchTechnique1D currently not supported

    # initialize
    get_next_coordinates_or_indices = search_technique.get_next_coordinates
@@ -54,16 +61,13 @@ def tune(searchspace: Searchspace, runner, tuning_options):

        # get configuration
        coords_or_index = coordinates_or_indices.pop()
-        # config = search_space.get_configuration(coords_or_index)
+        config = search_space_pyatf.get_configuration(coords_or_index)
        valid = True
        cost = None

-        # convert normalized coordinates of each parameter to range of bounds (from [0, 1] to [bound[0], bound[1]])
-        if isinstance(coords_or_index, tuple):
-            coords_or_index = tuple(b[0] + c * (b[1] - b[0]) for c, b in zip(coords_or_index, bounds) if c is not None)
-
        # evaluate the configuration
-        opt_result = cost_func(coords_or_index)
+        x = tuple([config[k] for k in searchspace.tune_params.keys()])
+        opt_result = cost_func(x, check_restrictions=False)

        # adjust opt_result to expected PyATF output in cost and valid
        if not isinstance(opt_result, (int, float)):
@@ -81,22 +85,6 @@ def tune(searchspace: Searchspace, runner, tuning_options):

    return cost_func.results

-    # scale variables in x to make 'eps' relevant for multiple variables
-    cost_func = CostFunc(searchspace, tuning_options, runner, scaling=True)
-
-    opt_result = None
-    try:
-        opt_result = searchtechnique(cost_func)
-    except StopCriterionReached as e:
-        searchtechnique.finalize()
-        if tuning_options.verbose:
-            print(e)
-
-    if opt_result and tuning_options.verbose:
-        print(opt_result.message)
-
-    return cost_func.results
-

# class TuningRun:
#     def __init__(self,