Skip to content

Commit 8c416ad

Browse files
committed
pyATF restrictions needed to be recombined in cases where multiple restrictions are defined on the same tunable parameter
1 parent f2a0f03 commit 8c416ad

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

kernel_tuner/searchspace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,12 +287,12 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so
287287
# define the Tunable Parameters
288288
def get_params():
289289
params = list()
290-
print("get_params")
291290
for index, (key, values) in enumerate(self.tune_params.items()):
292291
vi = get_interval(values)
293292
vals = Interval(vi[0], vi[1], vi[2]) if vi is not None else Set(*np.array(values).flatten())
294293
constraint = res_dict.get(key, None)
295-
if len(res_dict) == 0 and index == len(self.tune_params) - 1 and constraint is None:
294+
# in case of a leftover monolithic restriction, append at the last parameter
295+
if index == len(self.tune_params) - 1 and len(res_dict) == 0 and len(self.restrictions) == 1:
296296
res = self.restrictions[0][0]
297297
assert callable(res)
298298
constraint = res

kernel_tuner/util.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,8 +1052,35 @@ def to_equality_constraint(
10521052
finalized_constraint = to_equality_constraint(parsed_restriction, params_used)
10531053
if finalized_constraint is None:
10541054
# we must turn it into a general function
1055-
finalized_constraint = f"def r({', '.join(params_used)}): return {parsed_restriction} \n"
1055+
if format.lower() == "pyatf":
1056+
finalized_constraint = parsed_restriction
1057+
else:
1058+
finalized_constraint = f"def r({', '.join(params_used)}): return {parsed_restriction} \n"
10561059
parsed_restrictions.append((finalized_constraint, params_used))
1060+
1061+
# if pyATF, restrictions that are set on the same parameter must be combined into one
1062+
if format.lower() == "pyatf":
1063+
res_dict = dict()
1064+
registered_params = list()
1065+
registered_restrictions = list()
1066+
for param in tune_params.keys():
1067+
registered_params.append(param)
1068+
for index, (res, params) in enumerate(parsed_restrictions):
1069+
if index in registered_restrictions:
1070+
continue
1071+
if all(p in registered_params for p in params):
1072+
if param not in res_dict:
1073+
res_dict[param] = (list(), list())
1074+
res_dict[param][0].append(res)
1075+
res_dict[param][1].extend(params)
1076+
registered_restrictions.append(index)
1077+
# combine multiple restrictions into one
1078+
parsed_restrictions_pyatf = list()
1079+
for res_tuple in res_dict.values():
1080+
res, params_used = res_tuple
1081+
params_used = list(dict.fromkeys(params_used)) # param_used should only contain unique, dict preserves order
1082+
parsed_restrictions_pyatf.append((f"def r({', '.join(params_used)}): return {' and '.join(res)} \n", params_used))
1083+
return parsed_restrictions_pyatf
10571084
else:
10581085
# create one monolithic function
10591086
parsed_restrictions = ") and (".join(

0 commit comments

Comments
 (0)