Skip to content

Commit 49e786d

Browse files
committed
Search space construction can be deffered to a later time, split pyATF search space building and tunable parameter generation
1 parent 30de03e commit 49e786d

File tree

1 file changed

+51
-36
lines changed

1 file changed

+51
-36
lines changed

kernel_tuner/searchspace.py

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(
4747
restrictions,
4848
max_threads: int,
4949
block_size_names=default_block_size_names,
50+
defer_construction=False,
5051
build_neighbors_index=False,
5152
neighbor_method=None,
5253
from_cache: dict = None,
@@ -62,6 +63,7 @@ def __init__(
6263
Hamming: any parameter config with 1 different parameter value is a neighbor
6364
Optionally sort the searchspace by the order in which the parameter values were specified. By default, sort goes from first to last parameter, to reverse this use sort_last_param_first.
6465
Optionally an imported cache can be used instead with `from_cache`, in which case the `tune_params`, `restrictions` and `max_threads` arguments can be set to None, and construction is skipped.
66+
Optionally construction can be deffered to a later time by setting `defer_construction` to True, in which case the searchspace is not built on instantiation (experimental).
6567
"""
6668
# check the arguments
6769
if from_cache is not None:
@@ -76,7 +78,8 @@ def __init__(
7678
framework_l = framework.lower()
7779
restrictions = restrictions if restrictions is not None else []
7880
self.tune_params = tune_params
79-
self.tune_params_pyatf = None
81+
self.max_threads = max_threads
82+
self.block_size_names = block_size_names
8083
self._tensorspace = None
8184
self.tensor_dtype = torch.float32 if torch_available else None
8285
self.tensor_device = torch.device("cpu") if torch_available else None
@@ -160,17 +163,19 @@ def __init__(
160163
else:
161164
raise ValueError(f"Solver method {solver_method} not recognized.")
162165

163-
# build the search space
164-
self.list, self.__dict, self.size = searchspace_builder(block_size_names, max_threads, solver)
166+
if not defer_construction:
167+
# build the search space
168+
self.list, self.__dict, self.size = searchspace_builder(block_size_names, max_threads, solver)
165169

166170
# finalize construction
167-
self.__numpy = None
168-
self.num_params = len(self.tune_params)
169-
self.indices = np.arange(self.size)
170-
if neighbor_method is not None and neighbor_method != "Hamming":
171-
self.__prepare_neighbors_index()
172-
if build_neighbors_index:
173-
self.neighbors_index = self.__build_neighbors_index(neighbor_method)
171+
if not defer_construction:
172+
self.__numpy = None
173+
self.num_params = len(self.tune_params)
174+
self.indices = np.arange(self.size)
175+
if neighbor_method is not None and neighbor_method != "Hamming":
176+
self.__prepare_neighbors_index()
177+
if build_neighbors_index:
178+
self.neighbors_index = self.__build_neighbors_index(neighbor_method)
174179

175180
# def __build_searchspace_ortools(self, block_size_names: list, max_threads: int) -> Tuple[List[tuple], np.ndarray, dict, int]:
176181
# # Based on https://developers.google.com/optimization/cp/cp_solver#python_2
@@ -318,14 +323,15 @@ def all_smt(formula, keys) -> list:
318323

319324
return self.__parameter_space_list_to_lookup_and_return_type(parameter_space_list)
320325

321-
def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, solver: Solver):
322-
"""Builds the searchspace using pyATF."""
323-
from pyatf import TP, Interval, Set, Tuner
324-
from pyatf.cost_functions.generic import CostFunction
325-
from pyatf.search_techniques import Exhaustive
326+
def get_tune_params_pyatf(self, block_size_names: list = None, max_threads: int = None):
327+
"""Convert the tune_params and restrictions to pyATF tunable parameters."""
328+
from pyatf import TP, Interval, Set
326329

327-
# Define a bogus cost function
328-
costfunc = CostFunction(":") # bash no-op
330+
# if block_size_names or max_threads are not specified, use the defaults
331+
if block_size_names is None:
332+
block_size_names = self.block_size_names
333+
if max_threads is None:
334+
max_threads = self.max_threads
329335

330336
# add the Kernel Tuner default blocksize threads restrictions
331337
assert isinstance(self.restrictions, list)
@@ -359,27 +365,36 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so
359365
registered_restrictions.append(index)
360366

361367
# define the Tunable Parameters
362-
def get_params():
363-
params = list()
364-
for index, (key, values) in enumerate(self.tune_params.items()):
365-
vi = get_interval(values)
366-
vals = (
367-
Interval(vi[0], vi[1], vi[2]) if vi is not None and vi[2] != 0 else Set(*np.array(values).flatten())
368-
)
369-
constraint = res_dict.get(key, None)
370-
constraint_source = None
371-
if constraint is not None:
372-
constraint, constraint_source = constraint
373-
# in case of a leftover monolithic restriction, append at the last parameter
374-
if index == len(self.tune_params) - 1 and len(res_dict) == 0 and len(self.restrictions) == 1:
375-
res, params, source = self.restrictions[0]
376-
assert callable(res)
377-
constraint = res
378-
params.append(TP(key, vals, constraint, constraint_source))
379-
return params
368+
params = list()
369+
for index, (key, values) in enumerate(self.tune_params.items()):
370+
vi = get_interval(values)
371+
vals = (
372+
Interval(vi[0], vi[1], vi[2]) if vi is not None and vi[2] != 0 else Set(*np.array(values).flatten())
373+
)
374+
constraint = res_dict.get(key, None)
375+
constraint_source = None
376+
if constraint is not None:
377+
constraint, constraint_source = constraint
378+
# in case of a leftover monolithic restriction, append at the last parameter
379+
if index == len(self.tune_params) - 1 and len(res_dict) == 0 and len(self.restrictions) == 1:
380+
res, params, source = self.restrictions[0]
381+
assert callable(res)
382+
constraint = res
383+
params.append(TP(key, vals, constraint, constraint_source))
384+
return params
385+
386+
387+
def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, solver: Solver):
388+
"""Builds the searchspace using pyATF."""
389+
from pyatf import Tuner
390+
from pyatf.cost_functions.generic import CostFunction
391+
from pyatf.search_techniques import Exhaustive
392+
393+
# Define a bogus cost function
394+
costfunc = CostFunction(":") # bash no-op
380395

381396
# set data
382-
self.tune_params_pyatf = get_params()
397+
self.tune_params_pyatf = self.get_tune_params_pyatf(block_size_names, max_threads)
383398

384399
# tune
385400
_, _, tuning_data = (

0 commit comments

Comments
 (0)