1
1
"""Strategy that dynamically imports and enables the use of pyATF strategies."""
2
2
3
3
from importlib import import_module
4
+ import zlib
5
+ from pathlib import Path
4
6
5
7
from kernel_tuner .searchspace import Searchspace
6
8
from kernel_tuner .strategies import common
11
13
12
14
_options = dict (searchtechnique = (f"PyATF optimization algorithm to use, choose any from { supported_searchtechniques } " , "simulated_annealing" ))
13
15
16
+ def get_cache_checksum (d : dict ):
17
+ checksum = 0
18
+ for item in d .items ():
19
+ c1 = 1
20
+ for t in item :
21
+ c1 = zlib .adler32 (bytes (repr (t ),'utf-8' ), c1 )
22
+ checksum = checksum ^ c1
23
+ return checksum
24
+
14
25
def tune (searchspace : Searchspace , runner , tuning_options ):
15
26
from pyatf .search_techniques .search_technique import SearchTechnique
16
27
from pyatf .search_space import SearchSpace as pyATFSearchSpace
28
+ from pyatf import TP
29
+ try :
30
+ import dill
31
+ pyatf_search_space_caching = True
32
+ except ImportError :
33
+ from warnings import warn
34
+ pyatf_search_space_caching = False
35
+ warn ("dill is not installed, pyATF search space caching will not be used." )
17
36
18
37
# setup the Kernel Tuner functionalities
19
38
cost_func = CostFunc (searchspace , tuning_options , runner , scaling = False , snap = False , return_invalid = False )
@@ -29,18 +48,29 @@ def tune(searchspace: Searchspace, runner, tuning_options):
29
48
search_technique .initialize (len (searchspace .param_names ))
30
49
assert isinstance (search_technique , SearchTechnique ), f"Search technique { search_technique } is not a valid pyATF search technique."
31
50
51
+ # get the search space hash
52
+ tune_params_hashable = {k : "," .join ([str (i ) for i in v ]) if isinstance (v , (list , tuple )) else v for k , v in searchspace .tune_params .items ()}
53
+ searchspace_caches_folder = Path ("./pyatf_searchspace_caches" )
54
+ searchspace_caches_folder .mkdir (parents = True , exist_ok = True )
55
+ searchspace_cache_path = searchspace_caches_folder / Path (f"pyatf_searchspace_cache_{ get_cache_checksum (tune_params_hashable )} .pkl" )
56
+
32
57
# initialize the search space
33
- searchspace_pyatf = Searchspace (
34
- searchspace .tune_params ,
35
- tuning_options .restrictions_unmodified ,
36
- searchspace .max_threads ,
37
- searchspace .block_size_names ,
38
- defer_construction = True ,
39
- framework = "pyatf"
40
- )
41
- tune_params_pyatf = searchspace_pyatf .get_tune_params_pyatf ()
42
- assert isinstance (tune_params_pyatf , (tuple , list )), f"Tuning parameters must be a tuple or list of tuples, is { type (tune_params_pyatf )} ({ tune_params_pyatf } )."
43
- search_space_pyatf = pyATFSearchSpace (* tune_params_pyatf , enable_1d_access = False ) # SearchTechnique1D currently not supported
58
+ if not pyatf_search_space_caching or not searchspace_cache_path .exists ():
59
+ searchspace_pyatf = Searchspace (
60
+ searchspace .tune_params ,
61
+ tuning_options .restrictions_unmodified ,
62
+ searchspace .max_threads ,
63
+ searchspace .block_size_names ,
64
+ defer_construction = True ,
65
+ framework = "pyatf"
66
+ )
67
+ tune_params_pyatf = searchspace_pyatf .get_tune_params_pyatf ()
68
+ assert isinstance (tune_params_pyatf , (tuple , list )), f"Tuning parameters must be a tuple or list of tuples, is { type (tune_params_pyatf )} ({ tune_params_pyatf } )."
69
+ search_space_pyatf = pyATFSearchSpace (* tune_params_pyatf , enable_1d_access = False ) # SearchTechnique1D currently not supported
70
+ if pyatf_search_space_caching :
71
+ dill .dump (search_space_pyatf , open (searchspace_cache_path , "wb" ))
72
+ elif searchspace_cache_path .exists ():
73
+ search_space_pyatf = dill .load (open (searchspace_cache_path , "rb" ))
44
74
45
75
# initialize
46
76
get_next_coordinates_or_indices = search_technique .get_next_coordinates
@@ -86,209 +116,4 @@ def tune(searchspace: Searchspace, runner, tuning_options):
86
116
return cost_func .results
87
117
88
118
89
- # class TuningRun:
90
- # def __init__(self,
91
- # search_space: SearchSpace | Tuple[TP, ...],
92
- # cost_function: CostFunction,
93
- # search_technique: Optional[Union[SearchTechnique, SearchTechnique1D]],
94
- # verbosity: Optional[int],
95
- # log_file: Optional[str],
96
- # abort_condition: Optional[AbortCondition]):
97
- # if search_space is None:
98
- # raise ValueError('missing call to `Tuner.tuning_parameters(...)`: no tuning parameters defined')
99
-
100
- # # tuning data
101
- # self._search_space: SearchSpace
102
- # self._search_technique: SearchTechnique | SearchTechnique1D
103
- # self._abort_condition: AbortCondition # TODO: does not work (add initialization)
104
- # self._tps: Tuple[TP, ...]
105
- # self._tuning_data: Optional[TuningData] = None
106
- # self._cost_function: CostFunction = cost_function
107
-
108
- # # progress data
109
- # self._verbosity = verbosity
110
- # self._log_file: Optional[TextIO] = None
111
- # self._last_log_dump: Optional[int] = None
112
- # self._last_line_length: Optional[int] = None
113
- # self._tuning_start_ns: Optional[int] = None
114
-
115
- # # prepare search technique
116
- # self._search_technique: SearchTechnique | SearchTechnique1D = search_technique
117
- # if self._search_technique is None:
118
- # self._search_technique = AUCBandit()
119
- # if isinstance(self._search_technique, SearchTechnique):
120
- # self._get_next_coordinates_or_indices = self._search_technique.get_next_coordinates
121
- # self._coordinates_or_index_param_name = 'search_space_coordinates'
122
- # else:
123
- # self._get_next_coordinates_or_indices = self._search_technique.get_next_indices
124
- # self._coordinates_or_index_param_name = 'search_space_index'
125
- # self._coordinates_or_indices: Set[Union[Coordinates, Index]] = set()
126
- # self._costs: Dict[Union[Coordinates, Index], Cost] = {}
127
-
128
- # # generate search space
129
- # if isinstance(search_space, SearchSpace):
130
- # self._search_space = search_space
131
- # else:
132
- # self._search_space = SearchSpace(*search_space,
133
- # enable_1d_access=isinstance(self._search_technique, SearchTechnique1D),
134
- # verbosity=verbosity)
135
- # self._tps = self._search_space.tps
136
- # self._search_space_generation_ns = self._search_space.generation_ns
137
- # if self._verbosity >= 2:
138
- # print(f'search space size: {self._search_space.constrained_size}')
139
-
140
- # # prepare abort condition
141
- # self._abort_condition = abort_condition
142
- # if self._abort_condition is None:
143
- # self._abort_condition = Evaluations(len(self._search_space))
144
-
145
- # # open log file
146
- # if log_file:
147
- # Path(log_file).parent.mkdir(parents=True, exist_ok=True)
148
- # self._log_file = open(log_file, 'w')
149
-
150
- # def __del__(self):
151
- # if self._log_file:
152
- # self._log_file.close()
153
-
154
- # @property
155
- # def cost_function(self):
156
- # return self._cost_function
157
-
158
- # @property
159
- # def abort_condition(self):
160
- # return self._abort_condition
161
-
162
- # @property
163
- # def tuning_data(self):
164
- # return self._tuning_data
165
-
166
- # def flush_log(self):
167
- # if self._log_file:
168
- # self._log_file.seek(0)
169
- # json.dump(self._tuning_data.to_json(), self._log_file, indent=4)
170
- # self._log_file.truncate()
171
- # self._last_log_dump = time.perf_counter_ns()
172
-
173
- # def _print_progress(self, timestamp: datetime, cost: Optional[Cost] = None):
174
- # now = time.perf_counter_ns()
175
- # elapsed_ns = now - self._tuning_start_ns
176
- # elapsed_seconds = elapsed_ns // 1000000000
177
- # elapsed_time_str = (f'{elapsed_seconds // 3600}'
178
- # f':{elapsed_seconds // 60 % 60:02d}'
179
- # f':{elapsed_seconds % 60:02d}')
180
- # progress = self._abort_condition.progress(self._tuning_data)
181
- # if self._verbosity >= 3:
182
- # line = (f'\r{timestamp.strftime("%Y-%m-%dT%H:%M:%S")}'
183
- # f' evaluations: {self._tuning_data.number_of_evaluated_configurations}'
184
- # f' (valid: {self._tuning_data.number_of_evaluated_valid_configurations})'
185
- # f', min. cost: {self._tuning_data.min_cost()}'
186
- # f', valid: {cost is not None}'
187
- # f', cost: {cost}')
188
- # line_length = len(line)
189
- # if line_length < self._last_line_length:
190
- # line += ' ' * (self._last_line_length - line_length)
191
- # print(line)
192
- # if progress is None:
193
- # spinner_char = ('-', '\\', '|', '/')[(elapsed_ns // 500000000) % 4]
194
- # line = f'\rTuning: {spinner_char} {elapsed_time_str}\r'
195
- # print(line, end='')
196
- # else:
197
- # if now > self._tuning_start_ns and progress > 0:
198
- # eta_seconds = ceil(((now - self._tuning_start_ns) / progress
199
- # * (1 - progress)) / 1000000000)
200
- # eta_str = (f'{eta_seconds // 3600}'
201
- # f':{eta_seconds // 60 % 60:02d}'
202
- # f':{eta_seconds % 60:02d}')
203
- # else:
204
- # eta_str = '?'
205
- # filled = '█' * floor(progress * 80)
206
- # empty = ' ' * ceil((1 - progress) * 80)
207
- # line = (f'\rexploring search space: |{filled}{empty}|'
208
- # f' {progress * 100:6.2f}% {elapsed_time_str} (ETA: {eta_str})')
209
- # print(line, end='')
210
- # self._last_line_length = len(line)
211
-
212
- # def initialize(self):
213
- # # reset progress data
214
- # self._tuning_start_ns = time.perf_counter_ns()
215
- # self._last_line_length = 0
216
-
217
- # # create tuning data
218
- # self._tuning_data = TuningData(list(tp.to_json() for tp in self._tps),
219
- # self._search_space.constrained_size,
220
- # self._search_space.unconstrained_size,
221
- # self._search_space_generation_ns,
222
- # self._search_technique.to_json(),
223
- # self._abort_condition.to_json())
224
-
225
- # # write tuning data
226
- # self.flush_log()
227
-
228
- # # initialize search technique
229
- # if isinstance(self._search_technique, SearchTechnique1D):
230
- # self._search_technique.initialize(len(self._search_space))
231
- # else:
232
- # self._search_technique.initialize(self._search_space.num_tps)
233
-
234
- # def make_step(self):
235
- # # get new coordinates
236
- # if not self._coordinates_or_indices:
237
- # if self._costs:
238
- # self._search_technique.report_costs(self._costs)
239
- # self._costs.clear()
240
- # self._coordinates_or_indices.update(self._get_next_coordinates_or_indices())
241
-
242
- # # get configuration
243
- # coords_or_index = self._coordinates_or_indices.pop()
244
- # config = self._search_space.get_configuration(coords_or_index)
245
-
246
- # # run cost function
247
- # valid = True
248
- # try:
249
- # cost = self._cost_function(config)
250
- # except CostFunctionError as e:
251
- # if self._verbosity >= 3:
252
- # print('\r' + ' ' * self._last_line_length + '\r', end='')
253
- # print('Error raised: ' + e.message)
254
- # self._last_line_length = 0
255
- # cost = None
256
- # valid = False
257
- # except BaseException as e:
258
- # self._tuning_data.record_evaluation(config, False, None, **{
259
- # self._coordinates_or_index_param_name: coords_or_index
260
- # })
261
- # self.flush_log()
262
- # raise e
263
- # timestamp = self._tuning_data.record_evaluation(config, valid, cost, **{
264
- # self._coordinates_or_index_param_name: coords_or_index
265
- # })
266
- # self._costs[coords_or_index] = cost
267
-
268
- # # print progress and dump log file (at most once every 5 minutes)
269
- # if self._verbosity >= 1:
270
- # self._print_progress(timestamp, cost)
271
- # if self._log_file and (self._last_log_dump is None or time.perf_counter_ns() - self._last_log_dump > 3e11):
272
- # self.flush_log()
273
-
274
- # def finalize(self, sigint_received: bool = False):
275
- # self._search_technique.finalize()
276
- # self._tuning_data.record_tuning_finished(sigint_received)
277
-
278
- # # write tuning data to file
279
- # if self._log_file:
280
- # self.flush_log()
281
- # self._log_file.close()
282
- # self._log_file = None
283
-
284
- # if self._verbosity >= 1:
285
- # print('\nfinished tuning')
286
- # if self._verbosity >= 2:
287
- # if self._tuning_data.min_cost() is not None:
288
- # print('best configuration:')
289
- # for tp_name, tp_value in self._tuning_data.configuration_of_min_cost().items():
290
- # print(f' {tp_name} = {tp_value}')
291
- # print(f'min cost: {self._tuning_data.min_cost()}')
292
-
293
-
294
119
tune .__doc__ = common .get_strategy_docstring ("pyatf_strategies" , _options )
0 commit comments