Skip to content

Commit 4d19ee8

Browse files
committed
Improvements to code style
1 parent 739f954 commit 4d19ee8

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

kernel_tuner/searchspace.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,12 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
496496
def __add_restrictions(self, parameter_space: Problem) -> Problem:
497497
"""Add the user-specified restrictions as constraints on the parameter space."""
498498
restrictions = deepcopy(self.restrictions)
499-
if len(restrictions) == 1 and not isinstance(restrictions[0], (Constraint, FunctionConstraint, str)) and callable(restrictions[0]) and len(signature(restrictions[0]).parameters) == 1:
499+
# differentiate between old style monolithic with single 'p' argument and newer *args style
500+
if (len(restrictions) == 1
501+
and not isinstance(restrictions[0], (Constraint, FunctionConstraint, str))
502+
and callable(restrictions[0])
503+
and len(signature(restrictions[0]).parameters) == 1
504+
and len(self.param_names) > 1):
500505
restrictions = restrictions[0]
501506
if isinstance(restrictions, list):
502507
for restriction in restrictions:
@@ -507,16 +512,7 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
507512
required_params = restriction[1]
508513
restriction = restriction[0]
509514
if callable(restriction) and not isinstance(restriction, Constraint):
510-
# differentiate between old style monolithic with single 'p' argument and newer *args style
511-
if len(signature(restriction).parameters) == 1 and len(self.param_names) != 1:
512-
def restrictions_wrapper(*args):
513-
# raise ValueError(self.param_names, args, restriction, signature(restriction).parameters)
514-
# return restriction(dict(zip(self.param_names, args)))
515-
return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False)
516-
517-
restriction = FunctionConstraint(restrictions_wrapper)
518-
else:
519-
restriction = FunctionConstraint(restriction, required_params)
515+
restriction = FunctionConstraint(restriction, required_params)
520516

521517
# add as a Constraint
522518
all_params_required = all(param_name in required_params for param_name in self.param_names)
@@ -537,6 +533,7 @@ def restrictions_wrapper(*args):
537533
elif callable(restrictions):
538534

539535
def restrictions_wrapper(*args):
536+
"""Wrap old-style monolithic restrictions to work with multiple arguments."""
540537
return check_instance_restrictions(restrictions, dict(zip(self.param_names, args)), False)
541538

542539
parameter_space.addConstraint(FunctionConstraint(restrictions_wrapper), self.param_names)

0 commit comments

Comments
 (0)