21
21
from kernel_tuner .util import check_restrictions as check_instance_restrictions
22
22
from kernel_tuner .util import (
23
23
compile_restrictions ,
24
+ convert_constraint_lambdas ,
24
25
default_block_size_names ,
25
26
get_interval ,
26
27
)
@@ -74,15 +75,14 @@ def __init__(
74
75
len (restrictions ) > 0
75
76
and any (isinstance (restriction , str ) for restriction in restrictions )
76
77
and not (
77
- framework_l == "pysmt" or framework_l == "bruteforce" or solver_method .lower () == "pc_parallelsolver"
78
+ framework_l == "pysmt" or framework_l == "bruteforce" or framework_l == "pythonconstraint" or solver_method .lower () == "pc_parallelsolver"
78
79
)
79
80
):
80
81
self .restrictions = compile_restrictions (
81
82
restrictions ,
82
83
tune_params ,
83
84
monolithic = False ,
84
85
format = framework_l if framework_l == "pyatf" else None ,
85
- try_to_constraint = framework_l == "pythonconstraint" ,
86
86
)
87
87
88
88
# get the framework given the framework argument
@@ -289,9 +289,7 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so
289
289
# adding the default blocksize restriction requires recompilation because pyATF requires combined restrictions for the same parameter
290
290
max_block_size_product = f"{ ' * ' .join (valid_block_size_names )} <= { max_threads } "
291
291
restrictions = self ._modified_restrictions .copy () + [max_block_size_product ]
292
- self .restrictions = compile_restrictions (
293
- restrictions , self .tune_params , format = "pyatf" , try_to_constraint = False
294
- )
292
+ self .restrictions = compile_restrictions (restrictions , self .tune_params , format = "pyatf" )
295
293
296
294
# build a dictionary of the restrictions, combined based on last parameter
297
295
res_dict = dict ()
@@ -377,7 +375,7 @@ def __parameter_space_list_to_lookup_and_return_type(
377
375
parameter_space_dict ,
378
376
size_list ,
379
377
)
380
-
378
+
381
379
def __build_searchspace (self , block_size_names : list , max_threads : int , solver : Solver ):
382
380
"""Compute valid configurations in a search space based on restrictions and max_threads."""
383
381
# instantiate the parameter space with all the variables
@@ -386,6 +384,9 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
386
384
parameter_space .addVariable (str (param_name ), param_values )
387
385
388
386
# add the user-specified restrictions as constraints on the parameter space
387
+ if not isinstance (self .restrictions , (list , tuple )):
388
+ self .restrictions = [self .restrictions ]
389
+ self .restrictions = convert_constraint_lambdas (self .restrictions )
389
390
parameter_space = self .__add_restrictions (parameter_space )
390
391
391
392
# add the default blocksize threads restrictions last, because it is unlikely to reduce the parameter space by much
@@ -412,20 +413,29 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
412
413
for restriction in self .restrictions :
413
414
required_params = self .param_names
414
415
415
- # convert to a Constraint type if necessary
416
- if isinstance (restriction , tuple ):
417
- restriction , required_params , _ = restriction
416
+ # (un)wrap where necessary
417
+ if isinstance (restriction , tuple ) and len (restriction ) >= 2 :
418
+ required_params = restriction [1 ]
419
+ restriction = restriction [0 ]
418
420
if callable (restriction ) and not isinstance (restriction , Constraint ):
419
- restriction = FunctionConstraint (restriction )
420
-
421
- # add the Constraint
421
+ # def restrictions_wrapper(*args):
422
+ # return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False)
423
+ # print(restriction, isinstance(restriction, Constraint))
424
+ # restriction = FunctionConstraint(restrictions_wrapper)
425
+ restriction = FunctionConstraint (restriction , required_params )
426
+
427
+ # add as a Constraint
428
+ all_params_required = all (param_name in required_params for param_name in self .param_names )
429
+ variables = None if all_params_required else required_params
422
430
if isinstance (restriction , FunctionConstraint ):
423
- parameter_space .addConstraint (restriction , required_params )
431
+ parameter_space .addConstraint (restriction , variables )
424
432
elif isinstance (restriction , Constraint ):
425
- all_params_required = all (param_name in required_params for param_name in self .param_names )
426
- parameter_space .addConstraint (restriction , None if all_params_required else required_params )
427
- elif isinstance (restriction , str ) and self .solver_method .lower () == "pc_parallelsolver" :
428
- parameter_space .addConstraint (restriction )
433
+ parameter_space .addConstraint (restriction , variables )
434
+ elif isinstance (restriction , str ):
435
+ if self .solver_method .lower () == "pc_parallelsolver" :
436
+ parameter_space .addConstraint (restriction )
437
+ else :
438
+ parameter_space .addConstraint (restriction , variables )
429
439
else :
430
440
raise ValueError (f"Unrecognized restriction { restriction } " )
431
441
0 commit comments