Skip to content

Commit 160a81f

Browse files
committed
Minor bug fixes
1 parent 33c291d commit 160a81f

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

kernel_tuner/util.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,12 +250,15 @@ def check_restriction(restrict, params: dict) -> bool:
250250
# if it's a tuple, use only the parameters in the second argument to call the restriction
251251
elif (
252252
isinstance(restrict, tuple)
253-
and len(restrict) == 2
253+
and (len(restrict) == 2 or len(restrict) == 3)
254254
and callable(restrict[0])
255255
and isinstance(restrict[1], (list, tuple))
256256
):
257257
# unpack the tuple
258-
restrict, selected_params = restrict
258+
if len(restrict) == 2:
259+
restrict, selected_params = restrict
260+
else:
261+
restrict, selected_params, source = restrict
259262
# look up the selected parameters and their value
260263
selected_params = dict((key, params[key]) for key in selected_params)
261264
# call the restriction
@@ -1061,14 +1064,14 @@ def to_equality_constraint(
10611064
finalized_constraint = to_equality_constraint(parsed_restriction, params_used)
10621065
if finalized_constraint is None:
10631066
# we must turn it into a general function
1064-
if format.lower() == "pyatf":
1067+
if format is not None and format.lower() == "pyatf":
10651068
finalized_constraint = parsed_restriction
10661069
else:
10671070
finalized_constraint = f"def r({', '.join(params_used)}): return {parsed_restriction} \n"
10681071
parsed_restrictions.append((finalized_constraint, params_used))
10691072

10701073
# if pyATF, restrictions that are set on the same parameter must be combined into one
1071-
if format.lower() == "pyatf":
1074+
if format is not None and format.lower() == "pyatf":
10721075
res_dict = dict()
10731076
registered_params = list()
10741077
registered_restrictions = list()

0 commit comments

Comments
 (0)