Skip to content

Commit faeb52e

Browse files
committed
Moved compiling and parsing of restrictions to python-constraint where possible, amended tests where necessary
1 parent 5b5dbdb commit faeb52e

File tree

4 files changed

+57
-250
lines changed

4 files changed

+57
-250
lines changed

kernel_tuner/searchspace.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from kernel_tuner.util import check_restrictions as check_instance_restrictions
2222
from kernel_tuner.util import (
2323
compile_restrictions,
24+
convert_constraint_lambdas,
2425
default_block_size_names,
2526
get_interval,
2627
)
@@ -74,15 +75,14 @@ def __init__(
7475
len(restrictions) > 0
7576
and any(isinstance(restriction, str) for restriction in restrictions)
7677
and not (
77-
framework_l == "pysmt" or framework_l == "bruteforce" or solver_method.lower() == "pc_parallelsolver"
78+
framework_l == "pysmt" or framework_l == "bruteforce" or framework_l == "pythonconstraint" or solver_method.lower() == "pc_parallelsolver"
7879
)
7980
):
8081
self.restrictions = compile_restrictions(
8182
restrictions,
8283
tune_params,
8384
monolithic=False,
8485
format=framework_l if framework_l == "pyatf" else None,
85-
try_to_constraint=framework_l == "pythonconstraint",
8686
)
8787

8888
# get the framework given the framework argument
@@ -289,9 +289,7 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so
289289
# adding the default blocksize restriction requires recompilation because pyATF requires combined restrictions for the same parameter
290290
max_block_size_product = f"{' * '.join(valid_block_size_names)} <= {max_threads}"
291291
restrictions = self._modified_restrictions.copy() + [max_block_size_product]
292-
self.restrictions = compile_restrictions(
293-
restrictions, self.tune_params, format="pyatf", try_to_constraint=False
294-
)
292+
self.restrictions = compile_restrictions(restrictions, self.tune_params, format="pyatf")
295293

296294
# build a dictionary of the restrictions, combined based on last parameter
297295
res_dict = dict()
@@ -377,7 +375,7 @@ def __parameter_space_list_to_lookup_and_return_type(
377375
parameter_space_dict,
378376
size_list,
379377
)
380-
378+
381379
def __build_searchspace(self, block_size_names: list, max_threads: int, solver: Solver):
382380
"""Compute valid configurations in a search space based on restrictions and max_threads."""
383381
# instantiate the parameter space with all the variables
@@ -386,6 +384,9 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
386384
parameter_space.addVariable(str(param_name), param_values)
387385

388386
# add the user-specified restrictions as constraints on the parameter space
387+
if not isinstance(self.restrictions, (list, tuple)):
388+
self.restrictions = [self.restrictions]
389+
self.restrictions = convert_constraint_lambdas(self.restrictions)
389390
parameter_space = self.__add_restrictions(parameter_space)
390391

391392
# add the default blocksize threads restrictions last, because it is unlikely to reduce the parameter space by much
@@ -412,20 +413,29 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
412413
for restriction in self.restrictions:
413414
required_params = self.param_names
414415

415-
# convert to a Constraint type if necessary
416-
if isinstance(restriction, tuple):
417-
restriction, required_params, _ = restriction
416+
# (un)wrap where necessary
417+
if isinstance(restriction, tuple) and len(restriction) >= 2:
418+
required_params = restriction[1]
419+
restriction = restriction[0]
418420
if callable(restriction) and not isinstance(restriction, Constraint):
419-
restriction = FunctionConstraint(restriction)
420-
421-
# add the Constraint
421+
# def restrictions_wrapper(*args):
422+
# return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False)
423+
# print(restriction, isinstance(restriction, Constraint))
424+
# restriction = FunctionConstraint(restrictions_wrapper)
425+
restriction = FunctionConstraint(restriction, required_params)
426+
427+
# add as a Constraint
428+
all_params_required = all(param_name in required_params for param_name in self.param_names)
429+
variables = None if all_params_required else required_params
422430
if isinstance(restriction, FunctionConstraint):
423-
parameter_space.addConstraint(restriction, required_params)
431+
parameter_space.addConstraint(restriction, variables)
424432
elif isinstance(restriction, Constraint):
425-
all_params_required = all(param_name in required_params for param_name in self.param_names)
426-
parameter_space.addConstraint(restriction, None if all_params_required else required_params)
427-
elif isinstance(restriction, str) and self.solver_method.lower() == "pc_parallelsolver":
428-
parameter_space.addConstraint(restriction)
433+
parameter_space.addConstraint(restriction, variables)
434+
elif isinstance(restriction, str):
435+
if self.solver_method.lower() == "pc_parallelsolver":
436+
parameter_space.addConstraint(restriction)
437+
else:
438+
parameter_space.addConstraint(restriction, variables)
429439
else:
430440
raise ValueError(f"Unrecognized restriction {restriction}")
431441

0 commit comments

Comments
 (0)