2121from kernel_tuner .util import check_restrictions as check_instance_restrictions
2222from kernel_tuner .util import (
2323 compile_restrictions ,
24+ convert_constraint_lambdas ,
2425 default_block_size_names ,
2526 get_interval ,
2627)
@@ -74,15 +75,14 @@ def __init__(
7475 len (restrictions ) > 0
7576 and any (isinstance (restriction , str ) for restriction in restrictions )
7677 and not (
77- framework_l == "pysmt" or framework_l == "bruteforce" or solver_method .lower () == "pc_parallelsolver"
78+ framework_l == "pysmt" or framework_l == "bruteforce" or framework_l == "pythonconstraint" or solver_method .lower () == "pc_parallelsolver"
7879 )
7980 ):
8081 self .restrictions = compile_restrictions (
8182 restrictions ,
8283 tune_params ,
8384 monolithic = False ,
8485 format = framework_l if framework_l == "pyatf" else None ,
85- try_to_constraint = framework_l == "pythonconstraint" ,
8686 )
8787
8888 # get the framework given the framework argument
@@ -289,9 +289,7 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so
289289 # adding the default blocksize restriction requires recompilation because pyATF requires combined restrictions for the same parameter
290290 max_block_size_product = f"{ ' * ' .join (valid_block_size_names )} <= { max_threads } "
291291 restrictions = self ._modified_restrictions .copy () + [max_block_size_product ]
292- self .restrictions = compile_restrictions (
293- restrictions , self .tune_params , format = "pyatf" , try_to_constraint = False
294- )
292+ self .restrictions = compile_restrictions (restrictions , self .tune_params , format = "pyatf" )
295293
296294 # build a dictionary of the restrictions, combined based on last parameter
297295 res_dict = dict ()
@@ -377,7 +375,7 @@ def __parameter_space_list_to_lookup_and_return_type(
377375 parameter_space_dict ,
378376 size_list ,
379377 )
380-
378+
381379 def __build_searchspace (self , block_size_names : list , max_threads : int , solver : Solver ):
382380 """Compute valid configurations in a search space based on restrictions and max_threads."""
383381 # instantiate the parameter space with all the variables
@@ -386,6 +384,9 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
386384 parameter_space .addVariable (str (param_name ), param_values )
387385
388386 # add the user-specified restrictions as constraints on the parameter space
387+ if not isinstance (self .restrictions , (list , tuple )):
388+ self .restrictions = [self .restrictions ]
389+ self .restrictions = convert_constraint_lambdas (self .restrictions )
389390 parameter_space = self .__add_restrictions (parameter_space )
390391
391392 # add the default blocksize threads restrictions last, because it is unlikely to reduce the parameter space by much
@@ -412,20 +413,29 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
412413 for restriction in self .restrictions :
413414 required_params = self .param_names
414415
415- # convert to a Constraint type if necessary
416- if isinstance (restriction , tuple ):
417- restriction , required_params , _ = restriction
416+ # (un)wrap where necessary
417+ if isinstance (restriction , tuple ) and len (restriction ) >= 2 :
418+ required_params = restriction [1 ]
419+ restriction = restriction [0 ]
418420 if callable (restriction ) and not isinstance (restriction , Constraint ):
419- restriction = FunctionConstraint (restriction )
420-
421- # add the Constraint
421+ # def restrictions_wrapper(*args):
422+ # return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False)
423+ # print(restriction, isinstance(restriction, Constraint))
424+ # restriction = FunctionConstraint(restrictions_wrapper)
425+ restriction = FunctionConstraint (restriction , required_params )
426+
427+ # add as a Constraint
428+ all_params_required = all (param_name in required_params for param_name in self .param_names )
429+ variables = None if all_params_required else required_params
422430 if isinstance (restriction , FunctionConstraint ):
423- parameter_space .addConstraint (restriction , required_params )
431+ parameter_space .addConstraint (restriction , variables )
424432 elif isinstance (restriction , Constraint ):
425- all_params_required = all (param_name in required_params for param_name in self .param_names )
426- parameter_space .addConstraint (restriction , None if all_params_required else required_params )
427- elif isinstance (restriction , str ) and self .solver_method .lower () == "pc_parallelsolver" :
428- parameter_space .addConstraint (restriction )
433+ parameter_space .addConstraint (restriction , variables )
434+ elif isinstance (restriction , str ):
435+ if self .solver_method .lower () == "pc_parallelsolver" :
436+ parameter_space .addConstraint (restriction )
437+ else :
438+ parameter_space .addConstraint (restriction , variables )
429439 else :
430440 raise ValueError (f"Unrecognized restriction { restriction } " )
431441
0 commit comments