Commit 934be28

Implemented caching of PyATF searchspace object, both storage and retrieval

1 parent 36208d1

1 file changed: 41 additions and 216 deletions
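The cache key the diff below introduces is an order-independent checksum: each (key, value) pair of the tuning parameters is hashed with Adler-32 over its repr, and the per-item checksums are XOR-ed together, so dictionary insertion order does not change the result. A minimal sketch of how such a key could be computed; the example tune_params dict and its values are hypothetical, only the checksum function mirrors the diff:

import zlib

def get_cache_checksum(d: dict) -> int:
    # XOR of per-item Adler-32 checksums, so the result is order-independent
    checksum = 0
    for item in d.items():
        c1 = 1
        for t in item:
            c1 = zlib.adler32(bytes(repr(t), "utf-8"), c1)
        checksum ^= c1
    return checksum

# hypothetical tuning parameters; list values are flattened to comma-separated
# strings before hashing, as in the tune_params_hashable conversion in the diff
tune_params = {"block_size_x": [32, 64, 128], "tile_size": [1, 2, 4]}
hashable = {k: ",".join(str(i) for i in v) for k, v in tune_params.items()}
print(f"pyatf_searchspace_cache_{get_cache_checksum(hashable)}.pkl")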
@@ -1,6 +1,8 @@
 """Strategy that dynamically imports and enables the use of pyATF strategies."""
 
 from importlib import import_module
+import zlib
+from pathlib import Path
 
 from kernel_tuner.searchspace import Searchspace
 from kernel_tuner.strategies import common
@@ -11,9 +13,26 @@
 
 _options = dict(searchtechnique=(f"PyATF optimization algorithm to use, choose any from {supported_searchtechniques}", "simulated_annealing"))
 
+def get_cache_checksum(d: dict):
+    checksum=0
+    for item in d.items():
+        c1 = 1
+        for t in item:
+            c1 = zlib.adler32(bytes(repr(t),'utf-8'), c1)
+        checksum=checksum ^ c1
+    return checksum
+
 def tune(searchspace: Searchspace, runner, tuning_options):
     from pyatf.search_techniques.search_technique import SearchTechnique
     from pyatf.search_space import SearchSpace as pyATFSearchSpace
+    from pyatf import TP
+    try:
+        import dill
+        pyatf_search_space_caching = True
+    except ImportError:
+        from warnings import warn
+        pyatf_search_space_caching = False
+        warn("dill is not installed, pyATF search space caching will not be used.")
 
     # setup the Kernel Tuner functionalities
     cost_func = CostFunc(searchspace, tuning_options, runner, scaling=False, snap=False, return_invalid=False)
@@ -29,18 +48,29 @@ def tune(searchspace: Searchspace, runner, tuning_options):
     search_technique.initialize(len(searchspace.param_names))
     assert isinstance(search_technique, SearchTechnique), f"Search technique {search_technique} is not a valid pyATF search technique."
 
+    # get the search space hash
+    tune_params_hashable = {k: ",".join([str(i) for i in v]) if isinstance(v, (list, tuple)) else v for k, v in searchspace.tune_params.items()}
+    searchspace_caches_folder = Path("./pyatf_searchspace_caches")
+    searchspace_caches_folder.mkdir(parents=True, exist_ok=True)
+    searchspace_cache_path = searchspace_caches_folder / Path(f"pyatf_searchspace_cache_{get_cache_checksum(tune_params_hashable)}.pkl")
+
     # initialize the search space
-    searchspace_pyatf = Searchspace(
-        searchspace.tune_params,
-        tuning_options.restrictions_unmodified,
-        searchspace.max_threads,
-        searchspace.block_size_names,
-        defer_construction=True,
-        framework="pyatf"
-    )
-    tune_params_pyatf = searchspace_pyatf.get_tune_params_pyatf()
-    assert isinstance(tune_params_pyatf, (tuple, list)), f"Tuning parameters must be a tuple or list of tuples, is {type(tune_params_pyatf)} ({tune_params_pyatf})."
-    search_space_pyatf = pyATFSearchSpace(*tune_params_pyatf, enable_1d_access=False) # SearchTechnique1D currently not supported
+    if not pyatf_search_space_caching or not searchspace_cache_path.exists():
+        searchspace_pyatf = Searchspace(
+            searchspace.tune_params,
+            tuning_options.restrictions_unmodified,
+            searchspace.max_threads,
+            searchspace.block_size_names,
+            defer_construction=True,
+            framework="pyatf"
+        )
+        tune_params_pyatf = searchspace_pyatf.get_tune_params_pyatf()
+        assert isinstance(tune_params_pyatf, (tuple, list)), f"Tuning parameters must be a tuple or list of tuples, is {type(tune_params_pyatf)} ({tune_params_pyatf})."
+        search_space_pyatf = pyATFSearchSpace(*tune_params_pyatf, enable_1d_access=False) # SearchTechnique1D currently not supported
+        if pyatf_search_space_caching:
+            dill.dump(search_space_pyatf, open(searchspace_cache_path, "wb"))
+    elif searchspace_cache_path.exists():
+        search_space_pyatf = dill.load(open(searchspace_cache_path, "rb"))
 
     # initialize
     get_next_coordinates_or_indices = search_technique.get_next_coordinates
@@ -86,209 +116,4 @@ def tune(searchspace: Searchspace, runner, tuning_options):
     return cost_func.results
 
 
-# class TuningRun:
-#     def __init__(self,
-#                  search_space: SearchSpace | Tuple[TP, ...],
-#                  cost_function: CostFunction,
-#                  search_technique: Optional[Union[SearchTechnique, SearchTechnique1D]],
-#                  verbosity: Optional[int],
-#                  log_file: Optional[str],
-#                  abort_condition: Optional[AbortCondition]):
-#         if search_space is None:
-#             raise ValueError('missing call to `Tuner.tuning_parameters(...)`: no tuning parameters defined')
-
-#         # tuning data
-#         self._search_space: SearchSpace
-#         self._search_technique: SearchTechnique | SearchTechnique1D
-#         self._abort_condition: AbortCondition  # TODO: does not work (add initialization)
-#         self._tps: Tuple[TP, ...]
-#         self._tuning_data: Optional[TuningData] = None
-#         self._cost_function: CostFunction = cost_function
-
-#         # progress data
-#         self._verbosity = verbosity
-#         self._log_file: Optional[TextIO] = None
-#         self._last_log_dump: Optional[int] = None
-#         self._last_line_length: Optional[int] = None
-#         self._tuning_start_ns: Optional[int] = None
-
-#         # prepare search technique
-#         self._search_technique: SearchTechnique | SearchTechnique1D = search_technique
-#         if self._search_technique is None:
-#             self._search_technique = AUCBandit()
-#         if isinstance(self._search_technique, SearchTechnique):
-#             self._get_next_coordinates_or_indices = self._search_technique.get_next_coordinates
-#             self._coordinates_or_index_param_name = 'search_space_coordinates'
-#         else:
-#             self._get_next_coordinates_or_indices = self._search_technique.get_next_indices
-#             self._coordinates_or_index_param_name = 'search_space_index'
-#         self._coordinates_or_indices: Set[Union[Coordinates, Index]] = set()
-#         self._costs: Dict[Union[Coordinates, Index], Cost] = {}
-
-#         # generate search space
-#         if isinstance(search_space, SearchSpace):
-#             self._search_space = search_space
-#         else:
-#             self._search_space = SearchSpace(*search_space,
-#                                              enable_1d_access=isinstance(self._search_technique, SearchTechnique1D),
-#                                              verbosity=verbosity)
-#         self._tps = self._search_space.tps
-#         self._search_space_generation_ns = self._search_space.generation_ns
-#         if self._verbosity >= 2:
-#             print(f'search space size: {self._search_space.constrained_size}')
-
-#         # prepare abort condition
-#         self._abort_condition = abort_condition
-#         if self._abort_condition is None:
-#             self._abort_condition = Evaluations(len(self._search_space))
-
-#         # open log file
-#         if log_file:
-#             Path(log_file).parent.mkdir(parents=True, exist_ok=True)
-#             self._log_file = open(log_file, 'w')
-
-#     def __del__(self):
-#         if self._log_file:
-#             self._log_file.close()
-
-#     @property
-#     def cost_function(self):
-#         return self._cost_function
-
-#     @property
-#     def abort_condition(self):
-#         return self._abort_condition
-
-#     @property
-#     def tuning_data(self):
-#         return self._tuning_data
-
-#     def flush_log(self):
-#         if self._log_file:
-#             self._log_file.seek(0)
-#             json.dump(self._tuning_data.to_json(), self._log_file, indent=4)
-#             self._log_file.truncate()
-#             self._last_log_dump = time.perf_counter_ns()
-
-#     def _print_progress(self, timestamp: datetime, cost: Optional[Cost] = None):
-#         now = time.perf_counter_ns()
-#         elapsed_ns = now - self._tuning_start_ns
-#         elapsed_seconds = elapsed_ns // 1000000000
-#         elapsed_time_str = (f'{elapsed_seconds // 3600}'
-#                             f':{elapsed_seconds // 60 % 60:02d}'
-#                             f':{elapsed_seconds % 60:02d}')
-#         progress = self._abort_condition.progress(self._tuning_data)
-#         if self._verbosity >= 3:
-#             line = (f'\r{timestamp.strftime("%Y-%m-%dT%H:%M:%S")}'
-#                     f' evaluations: {self._tuning_data.number_of_evaluated_configurations}'
-#                     f' (valid: {self._tuning_data.number_of_evaluated_valid_configurations})'
-#                     f', min. cost: {self._tuning_data.min_cost()}'
-#                     f', valid: {cost is not None}'
-#                     f', cost: {cost}')
-#             line_length = len(line)
-#             if line_length < self._last_line_length:
-#                 line += ' ' * (self._last_line_length - line_length)
-#             print(line)
-#         if progress is None:
-#             spinner_char = ('-', '\\', '|', '/')[(elapsed_ns // 500000000) % 4]
-#             line = f'\rTuning: {spinner_char} {elapsed_time_str}\r'
-#             print(line, end='')
-#         else:
-#             if now > self._tuning_start_ns and progress > 0:
-#                 eta_seconds = ceil(((now - self._tuning_start_ns) / progress
-#                                     * (1 - progress)) / 1000000000)
-#                 eta_str = (f'{eta_seconds // 3600}'
-#                            f':{eta_seconds // 60 % 60:02d}'
-#                            f':{eta_seconds % 60:02d}')
-#             else:
-#                 eta_str = '?'
-#             filled = '█' * floor(progress * 80)
-#             empty = ' ' * ceil((1 - progress) * 80)
-#             line = (f'\rexploring search space: |{filled}{empty}|'
-#                     f' {progress * 100:6.2f}% {elapsed_time_str} (ETA: {eta_str})')
-#             print(line, end='')
-#         self._last_line_length = len(line)
-
-#     def initialize(self):
-#         # reset progress data
-#         self._tuning_start_ns = time.perf_counter_ns()
-#         self._last_line_length = 0
-
-#         # create tuning data
-#         self._tuning_data = TuningData(list(tp.to_json() for tp in self._tps),
-#                                        self._search_space.constrained_size,
-#                                        self._search_space.unconstrained_size,
-#                                        self._search_space_generation_ns,
-#                                        self._search_technique.to_json(),
-#                                        self._abort_condition.to_json())
-
-#         # write tuning data
-#         self.flush_log()
-
-#         # initialize search technique
-#         if isinstance(self._search_technique, SearchTechnique1D):
-#             self._search_technique.initialize(len(self._search_space))
-#         else:
-#             self._search_technique.initialize(self._search_space.num_tps)
-
-#     def make_step(self):
-#         # get new coordinates
-#         if not self._coordinates_or_indices:
-#             if self._costs:
-#                 self._search_technique.report_costs(self._costs)
-#                 self._costs.clear()
-#             self._coordinates_or_indices.update(self._get_next_coordinates_or_indices())
-
-#         # get configuration
-#         coords_or_index = self._coordinates_or_indices.pop()
-#         config = self._search_space.get_configuration(coords_or_index)
-
-#         # run cost function
-#         valid = True
-#         try:
-#             cost = self._cost_function(config)
-#         except CostFunctionError as e:
-#             if self._verbosity >= 3:
-#                 print('\r' + ' ' * self._last_line_length + '\r', end='')
-#                 print('Error raised: ' + e.message)
-#                 self._last_line_length = 0
-#             cost = None
-#             valid = False
-#         except BaseException as e:
-#             self._tuning_data.record_evaluation(config, False, None, **{
-#                 self._coordinates_or_index_param_name: coords_or_index
-#             })
-#             self.flush_log()
-#             raise e
-#         timestamp = self._tuning_data.record_evaluation(config, valid, cost, **{
-#             self._coordinates_or_index_param_name: coords_or_index
-#         })
-#         self._costs[coords_or_index] = cost
-
-#         # print progress and dump log file (at most once every 5 minutes)
-#         if self._verbosity >= 1:
-#             self._print_progress(timestamp, cost)
-#         if self._log_file and (self._last_log_dump is None or time.perf_counter_ns() - self._last_log_dump > 3e11):
-#             self.flush_log()
-
-#     def finalize(self, sigint_received: bool = False):
-#         self._search_technique.finalize()
-#         self._tuning_data.record_tuning_finished(sigint_received)
-
-#         # write tuning data to file
-#         if self._log_file:
-#             self.flush_log()
-#             self._log_file.close()
-#             self._log_file = None
-
-#         if self._verbosity >= 1:
-#             print('\nfinished tuning')
-#         if self._verbosity >= 2:
-#             if self._tuning_data.min_cost() is not None:
-#                 print('best configuration:')
-#                 for tp_name, tp_value in self._tuning_data.configuration_of_min_cost().items():
-#                     print(f' {tp_name} = {tp_value}')
-#                 print(f'min cost: {self._tuning_data.min_cost()}')
-
-
 tune.__doc__ = common.get_strategy_docstring("pyatf_strategies", _options)
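Taken together, the additions implement a simple build-or-load flow: construct the pyATF search space only when dill is unavailable or no cache file exists for the current parameters, and otherwise deserialize the previously stored object. A minimal standalone sketch of that pattern, with a trivial build_search_space stand-in for the actual pyATF SearchSpace construction and a hypothetical cache file name:

from pathlib import Path

try:
    import dill  # optional dependency; without it, caching is skipped
    caching_enabled = True
except ImportError:
    caching_enabled = False

def build_search_space():
    # stand-in for the (potentially expensive) pyATF SearchSpace construction
    return {"dummy": "search space"}

cache_dir = Path("./pyatf_searchspace_caches")
cache_dir.mkdir(parents=True, exist_ok=True)
cache_path = cache_dir / "pyatf_searchspace_cache_12345.pkl"  # hypothetical key

if not caching_enabled or not cache_path.exists():
    search_space = build_search_space()
    if caching_enabled:
        with open(cache_path, "wb") as fh:
            dill.dump(search_space, fh)  # storage: pickle the object for the next run
else:
    with open(cache_path, "rb") as fh:
        search_space = dill.load(fh)  # retrieval: reuse the cached object

The sketch wraps the file handles in context managers, whereas the diff passes open(...) directly to dill.dump and dill.load; the behaviour is the same apart from explicit closing of the cache file.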
