Skip to content

Commit 39b5f01

Browse files
committed
Added default blocksize restriction for pyATF, as well as raising an error on multiple restrictions defined on the same tunable parameter, and passing of restriction source
1 parent 41cd741 commit 39b5f01

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

kernel_tuner/searchspace.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -269,19 +269,31 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so
269269
# Define a bogus cost function
270270
costfunc = CostFunction(":") # bash no-op
271271

272-
# build a dictionary of the restrictions, combined based on last parameter
273-
print(self.restrictions)
272+
# add the Kernel Tuner default blocksize threads restrictions
274273
assert isinstance(self.restrictions, list)
274+
valid_block_size_names = list(
275+
block_size_name for block_size_name in block_size_names if block_size_name in self.param_names
276+
)
277+
if len(valid_block_size_names) > 0:
278+
# adding the default blocksize restriction requires recompilation because pyATF requires combined restrictions for the same parameter
279+
max_block_size_product = f"{' * '.join(valid_block_size_names)} <= {max_threads}"
280+
restrictions = self._modified_restrictions.copy() + [max_block_size_product]
281+
self.restrictions = compile_restrictions(restrictions, self.tune_params, format="pyatf", try_to_constraint=False)
282+
283+
# build a dictionary of the restrictions, combined based on last parameter
275284
res_dict = dict()
276285
registered_params = list()
277286
registered_restrictions = list()
278287
for param in self.tune_params.keys():
279288
registered_params.append(param)
280-
for index, (res, params) in enumerate(self.restrictions):
289+
for index, (res, params, source) in enumerate(self.restrictions):
281290
if index in registered_restrictions:
282291
continue
283292
if all(p in registered_params for p in params):
284-
res_dict[param] = res
293+
if param in res_dict:
294+
raise KeyError(f"`{param}` is already in res_dict with `{res_dict[param][1]}`, can't add `{source}`")
295+
res_dict[param] = (res, source)
296+
print(source, res, param, params)
285297
registered_restrictions.append(index)
286298

287299
# define the Tunable Parameters
@@ -291,12 +303,15 @@ def get_params():
291303
vi = get_interval(values)
292304
vals = Interval(vi[0], vi[1], vi[2]) if vi is not None and vi[2] != 0 else Set(*np.array(values).flatten())
293305
constraint = res_dict.get(key, None)
306+
constraint_source = None
307+
if constraint is not None:
308+
constraint, constraint_source = constraint
294309
# in case of a leftover monolithic restriction, append at the last parameter
295310
if index == len(self.tune_params) - 1 and len(res_dict) == 0 and len(self.restrictions) == 1:
296-
res = self.restrictions[0][0]
311+
res, params, source = self.restrictions[0]
297312
assert callable(res)
298313
constraint = res
299-
params.append(TP(key, vals, constraint))
314+
params.append(TP(key, vals, constraint, constraint_source))
300315
return params
301316

302317
# tune
@@ -369,7 +384,7 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
369384
):
370385
self._modified_restrictions.append(max_block_size_product)
371386
if isinstance(self.restrictions, list):
372-
self.restrictions.append((MaxProdConstraint(max_threads), valid_block_size_names))
387+
self.restrictions.append((MaxProdConstraint(max_threads), valid_block_size_names, None))
373388

374389
# construct the parameter space with the constraints applied
375390
return parameter_space.getSolutionsAsListDict(order=self.param_names)
@@ -382,7 +397,7 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
382397

383398
# convert to a Constraint type if necessary
384399
if isinstance(restriction, tuple):
385-
restriction, required_params = restriction
400+
restriction, required_params, _ = restriction
386401
if callable(restriction) and not isinstance(restriction, Constraint):
387402
restriction = FunctionConstraint(restriction)
388403

0 commit comments

Comments
 (0)