31
31
from kernel_tuner .util import check_restrictions as check_instance_restrictions
32
32
from kernel_tuner .util import (
33
33
compile_restrictions ,
34
+ convert_constraint_lambdas ,
34
35
default_block_size_names ,
35
36
get_interval ,
36
37
)
@@ -114,15 +115,14 @@ def __init__(
114
115
)
115
116
)
116
117
and not (
117
- framework_l == "pysmt" or framework_l == "bruteforce" or solver_method .lower () == "pc_parallelsolver"
118
+ framework_l == "pysmt" or framework_l == "bruteforce" or framework_l == "pythonconstraint" or solver_method .lower () == "pc_parallelsolver"
118
119
)
119
120
):
120
121
self .restrictions = compile_restrictions (
121
122
restrictions ,
122
123
tune_params ,
123
124
monolithic = False ,
124
125
format = framework_l if framework_l == "pyatf" else None ,
125
- try_to_constraint = framework_l == "pythonconstraint" ,
126
126
)
127
127
128
128
# if an imported cache, skip building and set the values directly
@@ -342,9 +342,7 @@ def get_tune_params_pyatf(self, block_size_names: list = None, max_threads: int
342
342
# adding the default blocksize restriction requires recompilation because pyATF requires combined restrictions for the same parameter
343
343
max_block_size_product = f"{ ' * ' .join (valid_block_size_names )} <= { max_threads } "
344
344
restrictions = self ._modified_restrictions .copy () + [max_block_size_product ]
345
- self .restrictions = compile_restrictions (
346
- restrictions , self .tune_params , format = "pyatf" , try_to_constraint = False
347
- )
345
+ self .restrictions = compile_restrictions (restrictions , self .tune_params , format = "pyatf" )
348
346
349
347
# build a dictionary of the restrictions, combined based on last parameter
350
348
res_dict = dict ()
@@ -371,6 +369,7 @@ def get_tune_params_pyatf(self, block_size_names: list = None, max_threads: int
371
369
vals = (
372
370
Interval (vi [0 ], vi [1 ], vi [2 ]) if vi is not None and vi [2 ] != 0 else Set (* np .array (values ).flatten ())
373
371
)
372
+ assert vals is not None , f"Values for parameter { key } are None, this should not happen."
374
373
constraint = res_dict .get (key , None )
375
374
constraint_source = None
376
375
if constraint is not None :
@@ -442,7 +441,7 @@ def __parameter_space_list_to_lookup_and_return_type(
442
441
parameter_space_dict ,
443
442
size_list ,
444
443
)
445
-
444
+
446
445
def __build_searchspace (self , block_size_names : list , max_threads : int , solver : Solver ):
447
446
"""Compute valid configurations in a search space based on restrictions and max_threads."""
448
447
# instantiate the parameter space with all the variables
@@ -451,6 +450,9 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
451
450
parameter_space .addVariable (str (param_name ), param_values )
452
451
453
452
# add the user-specified restrictions as constraints on the parameter space
453
+ if not isinstance (self .restrictions , (list , tuple )):
454
+ self .restrictions = [self .restrictions ]
455
+ self .restrictions = convert_constraint_lambdas (self .restrictions )
454
456
parameter_space = self .__add_restrictions (parameter_space )
455
457
456
458
# add the default blocksize threads restrictions last, because it is unlikely to reduce the parameter space by much
@@ -477,20 +479,29 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
477
479
for restriction in self .restrictions :
478
480
required_params = self .param_names
479
481
480
- # convert to a Constraint type if necessary
481
- if isinstance (restriction , tuple ):
482
- restriction , required_params , _ = restriction
482
+ # (un)wrap where necessary
483
+ if isinstance (restriction , tuple ) and len (restriction ) >= 2 :
484
+ required_params = restriction [1 ]
485
+ restriction = restriction [0 ]
483
486
if callable (restriction ) and not isinstance (restriction , Constraint ):
484
- restriction = FunctionConstraint (restriction )
485
-
486
- # add the Constraint
487
+ # def restrictions_wrapper(*args):
488
+ # return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False)
489
+ # print(restriction, isinstance(restriction, Constraint))
490
+ # restriction = FunctionConstraint(restrictions_wrapper)
491
+ restriction = FunctionConstraint (restriction , required_params )
492
+
493
+ # add as a Constraint
494
+ all_params_required = all (param_name in required_params for param_name in self .param_names )
495
+ variables = None if all_params_required else required_params
487
496
if isinstance (restriction , FunctionConstraint ):
488
- parameter_space .addConstraint (restriction , required_params )
497
+ parameter_space .addConstraint (restriction , variables )
489
498
elif isinstance (restriction , Constraint ):
490
- all_params_required = all (param_name in required_params for param_name in self .param_names )
491
- parameter_space .addConstraint (restriction , None if all_params_required else required_params )
492
- elif isinstance (restriction , str ) and self .solver_method .lower () == "pc_parallelsolver" :
493
- parameter_space .addConstraint (restriction )
499
+ parameter_space .addConstraint (restriction , variables )
500
+ elif isinstance (restriction , str ):
501
+ if self .solver_method .lower () == "pc_parallelsolver" :
502
+ parameter_space .addConstraint (restriction )
503
+ else :
504
+ parameter_space .addConstraint (restriction , variables )
494
505
else :
495
506
raise ValueError (f"Unrecognized restriction type { type (restriction )} ({ restriction } )" )
496
507
0 commit comments