Skip to content

Commit 41cd741

Browse files
committed
Restriction sources are passed back after compilation to aid pyATF
1 parent 5e77fe4 commit 41cd741

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

kernel_tuner/util.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,7 @@ def to_equality_constraint(
10631063
res_dict = dict()
10641064
registered_params = list()
10651065
registered_restrictions = list()
1066+
parsed_restrictions_pyatf = list()
10661067
for param in tune_params.keys():
10671068
registered_params.append(param)
10681069
for index, (res, params) in enumerate(parsed_restrictions):
@@ -1075,12 +1076,11 @@ def to_equality_constraint(
10751076
res_dict[param][1].extend(params)
10761077
registered_restrictions.append(index)
10771078
# combine multiple restrictions into one
1078-
parsed_restrictions_pyatf = list()
10791079
for res_tuple in res_dict.values():
10801080
res, params_used = res_tuple
10811081
params_used = list(dict.fromkeys(params_used)) # param_used should only contain unique, dict preserves order
1082-
parsed_restrictions_pyatf.append((f"def r({', '.join(params_used)}): return {' and '.join(res)} \n", params_used))
1083-
return parsed_restrictions_pyatf
1082+
parsed_restrictions_pyatf.append((f"def r({', '.join(params_used)}): return ({') and ('.join(res)}) \n", params_used))
1083+
parsed_restrictions = parsed_restrictions_pyatf
10841084
else:
10851085
# create one monolithic function
10861086
parsed_restrictions = ") and (".join(
@@ -1114,8 +1114,8 @@ def to_equality_constraint(
11141114

11151115
def compile_restrictions(
11161116
restrictions: list, tune_params: dict, monolithic=False, format=None, try_to_constraint=True
1117-
) -> list[tuple[Union[str, Constraint, FunctionType], list[str]]]:
1118-
"""Parses restrictions from a list of strings into a list of strings, Functions, or Constraints (if `try_to_constraint`) and parameters used, or a single Function if monolithic is true."""
1117+
) -> list[tuple[Union[str, Constraint, FunctionType], list[str], Union[str, None]]]:
1118+
"""Parses restrictions from a list of strings into a list of strings, Functions, or Constraints (if `try_to_constraint`) and parameters used and source, or a single Function if monolithic is true."""
11191119
# filter the restrictions to get only the strings
11201120
restrictions_str, restrictions_ignore = [], []
11211121
for r in restrictions:
@@ -1135,10 +1135,10 @@ def compile_restrictions(
11351135
# if it's a string, parse it to a function
11361136
code_object = compile(restriction, "<string>", "exec")
11371137
func = FunctionType(code_object.co_consts[0], globals())
1138-
compiled_restrictions.append((func, params_used))
1138+
compiled_restrictions.append((func, params_used, restriction))
11391139
elif isinstance(restriction, Constraint):
11401140
# otherwise it already is a Constraint, pass it directly
1141-
compiled_restrictions.append((restriction, params_used))
1141+
compiled_restrictions.append((restriction, params_used, None))
11421142
else:
11431143
raise ValueError(f"Restriction {restriction} is neither a string or Constraint {type(restriction)}")
11441144

@@ -1150,9 +1150,10 @@ def compile_restrictions(
11501150
noncompiled_restrictions = []
11511151
for r in restrictions_ignore:
11521152
if isinstance(r, tuple) and len(r) == 2 and isinstance(r[1], (list, tuple)):
1153-
noncompiled_restrictions.append(r)
1153+
restriction, params_used = r
1154+
noncompiled_restrictions.append((restriction, params_used, restriction))
11541155
else:
1155-
noncompiled_restrictions.append((r, ()))
1156+
noncompiled_restrictions.append((r, [], r))
11561157
return noncompiled_restrictions + compiled_restrictions
11571158

11581159

0 commit comments

Comments
 (0)