Skip to content

Commit 3d12130

Browse files
committed
Completed merge with recent changes to searchspace
2 parents 08a1029 + faeb52e commit 3d12130

File tree

12 files changed

+225
-261
lines changed

12 files changed

+225
-261
lines changed

kernel_tuner/interface.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -614,16 +614,24 @@ def tune_kernel(
614614
tuning_options = Options([(k, opts[k]) for k in _tuning_options.keys()])
615615
device_options = Options([(k, opts[k]) for k in _device_options.keys()])
616616
tuning_options["unique_results"] = {}
617-
if strategy_options and "max_fevals" in strategy_options:
618-
tuning_options["max_fevals"] = strategy_options["max_fevals"]
619-
if strategy_options and "time_limit" in strategy_options:
620-
tuning_options["time_limit"] = strategy_options["time_limit"]
621617

618+
# copy some values from strategy_options
619+
searchspace_construction_options = {}
620+
if strategy_options:
621+
if "max_fevals" in strategy_options:
622+
tuning_options["max_fevals"] = strategy_options["max_fevals"]
623+
if "time_limit" in strategy_options:
624+
tuning_options["time_limit"] = strategy_options["time_limit"]
625+
if "searchspace_construction_options" in strategy_options:
626+
searchspace_construction_options = strategy_options["searchspace_construction_options"]
627+
628+
# log the user inputs
622629
logging.debug("tune_kernel called")
623630
logging.debug("kernel_options: %s", util.get_config_string(kernel_options))
624631
logging.debug("tuning_options: %s", util.get_config_string(tuning_options))
625632
logging.debug("device_options: %s", util.get_config_string(device_options))
626633

634+
# check whether the selected strategy and options are valid
627635
strategy_string = strategy
628636
if strategy:
629637
if strategy in strategy_map:
@@ -669,7 +677,7 @@ def preprocess_cache(filepath):
669677

670678
# create search space
671679
tuning_options.restrictions_unmodified = deepcopy(restrictions)
672-
searchspace = Searchspace(tune_params, restrictions, runner.dev.max_threads)
680+
searchspace = Searchspace(tune_params, restrictions, runner.dev.max_threads, **searchspace_construction_options)
673681
restrictions = searchspace._modified_restrictions
674682
tuning_options.restrictions = restrictions
675683
if verbose:

kernel_tuner/searchspace.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from kernel_tuner.util import check_restrictions as check_instance_restrictions
3232
from kernel_tuner.util import (
3333
compile_restrictions,
34+
convert_constraint_lambdas,
3435
default_block_size_names,
3536
get_interval,
3637
)
@@ -114,15 +115,14 @@ def __init__(
114115
)
115116
)
116117
and not (
117-
framework_l == "pysmt" or framework_l == "bruteforce" or solver_method.lower() == "pc_parallelsolver"
118+
framework_l == "pysmt" or framework_l == "bruteforce" or framework_l == "pythonconstraint" or solver_method.lower() == "pc_parallelsolver"
118119
)
119120
):
120121
self.restrictions = compile_restrictions(
121122
restrictions,
122123
tune_params,
123124
monolithic=False,
124125
format=framework_l if framework_l == "pyatf" else None,
125-
try_to_constraint=framework_l == "pythonconstraint",
126126
)
127127

128128
# if an imported cache, skip building and set the values directly
@@ -342,9 +342,7 @@ def get_tune_params_pyatf(self, block_size_names: list = None, max_threads: int
342342
# adding the default blocksize restriction requires recompilation because pyATF requires combined restrictions for the same parameter
343343
max_block_size_product = f"{' * '.join(valid_block_size_names)} <= {max_threads}"
344344
restrictions = self._modified_restrictions.copy() + [max_block_size_product]
345-
self.restrictions = compile_restrictions(
346-
restrictions, self.tune_params, format="pyatf", try_to_constraint=False
347-
)
345+
self.restrictions = compile_restrictions(restrictions, self.tune_params, format="pyatf")
348346

349347
# build a dictionary of the restrictions, combined based on last parameter
350348
res_dict = dict()
@@ -371,6 +369,7 @@ def get_tune_params_pyatf(self, block_size_names: list = None, max_threads: int
371369
vals = (
372370
Interval(vi[0], vi[1], vi[2]) if vi is not None and vi[2] != 0 else Set(*np.array(values).flatten())
373371
)
372+
assert vals is not None, f"Values for parameter {key} are None, this should not happen."
374373
constraint = res_dict.get(key, None)
375374
constraint_source = None
376375
if constraint is not None:
@@ -442,7 +441,7 @@ def __parameter_space_list_to_lookup_and_return_type(
442441
parameter_space_dict,
443442
size_list,
444443
)
445-
444+
446445
def __build_searchspace(self, block_size_names: list, max_threads: int, solver: Solver):
447446
"""Compute valid configurations in a search space based on restrictions and max_threads."""
448447
# instantiate the parameter space with all the variables
@@ -451,6 +450,9 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
451450
parameter_space.addVariable(str(param_name), param_values)
452451

453452
# add the user-specified restrictions as constraints on the parameter space
453+
if not isinstance(self.restrictions, (list, tuple)):
454+
self.restrictions = [self.restrictions]
455+
self.restrictions = convert_constraint_lambdas(self.restrictions)
454456
parameter_space = self.__add_restrictions(parameter_space)
455457

456458
# add the default blocksize threads restrictions last, because it is unlikely to reduce the parameter space by much
@@ -477,20 +479,29 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
477479
for restriction in self.restrictions:
478480
required_params = self.param_names
479481

480-
# convert to a Constraint type if necessary
481-
if isinstance(restriction, tuple):
482-
restriction, required_params, _ = restriction
482+
# (un)wrap where necessary
483+
if isinstance(restriction, tuple) and len(restriction) >= 2:
484+
required_params = restriction[1]
485+
restriction = restriction[0]
483486
if callable(restriction) and not isinstance(restriction, Constraint):
484-
restriction = FunctionConstraint(restriction)
485-
486-
# add the Constraint
487+
# def restrictions_wrapper(*args):
488+
# return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False)
489+
# print(restriction, isinstance(restriction, Constraint))
490+
# restriction = FunctionConstraint(restrictions_wrapper)
491+
restriction = FunctionConstraint(restriction, required_params)
492+
493+
# add as a Constraint
494+
all_params_required = all(param_name in required_params for param_name in self.param_names)
495+
variables = None if all_params_required else required_params
487496
if isinstance(restriction, FunctionConstraint):
488-
parameter_space.addConstraint(restriction, required_params)
497+
parameter_space.addConstraint(restriction, variables)
489498
elif isinstance(restriction, Constraint):
490-
all_params_required = all(param_name in required_params for param_name in self.param_names)
491-
parameter_space.addConstraint(restriction, None if all_params_required else required_params)
492-
elif isinstance(restriction, str) and self.solver_method.lower() == "pc_parallelsolver":
493-
parameter_space.addConstraint(restriction)
499+
parameter_space.addConstraint(restriction, variables)
500+
elif isinstance(restriction, str):
501+
if self.solver_method.lower() == "pc_parallelsolver":
502+
parameter_space.addConstraint(restriction)
503+
else:
504+
parameter_space.addConstraint(restriction, variables)
494505
else:
495506
raise ValueError(f"Unrecognized restriction type {type(restriction)} ({restriction})")
496507

kernel_tuner/strategies/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ def make_strategy_options_doc(strategy_options):
5151

5252
def get_options(strategy_options, options):
5353
"""Get the strategy-specific options or their defaults from user-supplied strategy_options."""
54-
accepted = list(options.keys()) + ["max_fevals", "time_limit"]
54+
accepted = list(options.keys()) + ["max_fevals", "time_limit", "searchspace_construction_options"]
5555
for key in strategy_options:
5656
if key not in accepted:
57-
raise ValueError(f"Unrecognized option {key} in strategy_options")
57+
raise ValueError(f"Unrecognized option {key} in strategy_options (allowed: {accepted})")
5858
assert isinstance(options, dict)
5959
return [strategy_options.get(opt, default) for opt, (_, default) in options.items()]
6060

0 commit comments

Comments
 (0)