Skip to content

Commit 31ad0c8

Browse files
committed
Added per-parameter restrictions registration for pyATF
1 parent 3be5685 commit 31ad0c8

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

kernel_tuner/searchspace.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -269,25 +269,29 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so
269269
# Define a bogus cost function
270270
costfunc = CostFunction(":") # bash no-op
271271

272-
# build a dictionary of the restrictions
272+
# build a dictionary of the restrictions, combined based on last parameter
273273
print(self.restrictions)
274274
assert isinstance(self.restrictions, list)
275275
res_dict = dict()
276-
# for res, params in self.restrictions:
277-
# key = params[0]
278-
# assert callable(res)
279-
# assert key not in res_dict
280-
# res_dict[key] = res
281-
# print(res_dict)
276+
registered_params = list()
277+
registered_restrictions = list()
278+
for param in self.tune_params.keys():
279+
registered_params.append(param)
280+
for index, (res, params) in enumerate(self.restrictions):
281+
if index in registered_restrictions:
282+
continue
283+
if all(p in registered_params for p in params):
284+
res_dict[param] = res
285+
registered_restrictions.append(index)
282286

283287
# define the Tunable Parameters
284288
def get_params():
285289
params = list()
286290
print("get_params")
287291
for index, (key, values) in enumerate(self.tune_params.items()):
288-
vals = Set(*values.flatten()) # TODO check if can be interval
292+
vals = Set(*np.array(values).flatten()) # TODO check if can be interval
289293
constraint = res_dict.get(key, None)
290-
if index == len(self.tune_params) - 1 and constraint is None:
294+
if len(res_dict) == 0 and index == len(self.tune_params) - 1 and constraint is None:
291295
res = self.restrictions[0][0]
292296
assert callable(res)
293297
constraint = res
@@ -296,7 +300,7 @@ def get_params():
296300

297301
# tune
298302
_, _, tuning_data = (
299-
Tuner().silent(True).tuning_parameters(*get_params()).search_technique(Exhaustive()).tune(costfunc)
303+
Tuner().verbosity(0).tuning_parameters(*get_params()).search_technique(Exhaustive()).tune(costfunc)
300304
)
301305

302306
# transform the result into a list of parameter configurations for validation

kernel_tuner/util.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,8 +1007,6 @@ def to_equality_constraint(
10071007
return ValueError(f"Not possible: comparator should be '==' or '!=', is {comparator}")
10081008
return None
10091009

1010-
# TODO if format == "pyatf", combine based on last parameter
1011-
10121010
# create the parsed restrictions
10131011
if monolithic is False:
10141012
# split into multiple restrictions where possible
@@ -1086,8 +1084,6 @@ def compile_restrictions(
10861084
restrictions_str, tune_params, monolithic=monolithic, format=format, try_to_constraint=try_to_constraint
10871085
)
10881086

1089-
# TODO if format == "pyatf", return a dictionary instead of a list
1090-
10911087
# compile the parsed restrictions into a function
10921088
compiled_restrictions: list[tuple] = list()
10931089
for restriction, params_used in parsed_restrictions:

0 commit comments

Comments
 (0)