@@ -496,7 +496,12 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
496496 def __add_restrictions (self , parameter_space : Problem ) -> Problem :
497497 """Add the user-specified restrictions as constraints on the parameter space."""
498498 restrictions = deepcopy (self .restrictions )
499- if len (restrictions ) == 1 and not isinstance (restrictions [0 ], (Constraint , FunctionConstraint , str )) and callable (restrictions [0 ]) and len (signature (restrictions [0 ]).parameters ) == 1 :
499+ # differentiate between old style monolithic with single 'p' argument and newer *args style
500+ if (len (restrictions ) == 1
501+ and not isinstance (restrictions [0 ], (Constraint , FunctionConstraint , str ))
502+ and callable (restrictions [0 ])
503+ and len (signature (restrictions [0 ]).parameters ) == 1
504+ and len (self .param_names ) > 1 ):
500505 restrictions = restrictions [0 ]
501506 if isinstance (restrictions , list ):
502507 for restriction in restrictions :
@@ -507,16 +512,7 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
507512 required_params = restriction [1 ]
508513 restriction = restriction [0 ]
509514 if callable (restriction ) and not isinstance (restriction , Constraint ):
510- # differentiate between old style monolithic with single 'p' argument and newer *args style
511- if len (signature (restriction ).parameters ) == 1 and len (self .param_names ) != 1 :
512- def restrictions_wrapper (* args ):
513- # raise ValueError(self.param_names, args, restriction, signature(restriction).parameters)
514- # return restriction(dict(zip(self.param_names, args)))
515- return check_instance_restrictions (restriction , dict (zip (self .param_names , args )), False )
516-
517- restriction = FunctionConstraint (restrictions_wrapper )
518- else :
519- restriction = FunctionConstraint (restriction , required_params )
515+ restriction = FunctionConstraint (restriction , required_params )
520516
521517 # add as a Constraint
522518 all_params_required = all (param_name in required_params for param_name in self .param_names )
@@ -537,6 +533,7 @@ def restrictions_wrapper(*args):
537533 elif callable (restrictions ):
538534
539535 def restrictions_wrapper (* args ):
536+ """Wrap old-style monolithic restrictions to work with multiple arguments."""
540537 return check_instance_restrictions (restrictions , dict (zip (self .param_names , args )), False )
541538
542539 parameter_space .addConstraint (FunctionConstraint (restrictions_wrapper ), self .param_names )
0 commit comments