@@ -269,19 +269,31 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so
269
269
# Define a bogus cost function
270
270
costfunc = CostFunction (":" ) # bash no-op
271
271
272
- # build a dictionary of the restrictions, combined based on last parameter
273
- print (self .restrictions )
272
+ # add the Kernel Tuner default blocksize threads restrictions
274
273
assert isinstance (self .restrictions , list )
274
+ valid_block_size_names = list (
275
+ block_size_name for block_size_name in block_size_names if block_size_name in self .param_names
276
+ )
277
+ if len (valid_block_size_names ) > 0 :
278
+ # adding the default blocksize restriction requires recompilation because pyATF requires combined restrictions for the same parameter
279
+ max_block_size_product = f"{ ' * ' .join (valid_block_size_names )} <= { max_threads } "
280
+ restrictions = self ._modified_restrictions .copy () + [max_block_size_product ]
281
+ self .restrictions = compile_restrictions (restrictions , self .tune_params , format = "pyatf" , try_to_constraint = False )
282
+
283
+ # build a dictionary of the restrictions, combined based on last parameter
275
284
res_dict = dict ()
276
285
registered_params = list ()
277
286
registered_restrictions = list ()
278
287
for param in self .tune_params .keys ():
279
288
registered_params .append (param )
280
- for index , (res , params ) in enumerate (self .restrictions ):
289
+ for index , (res , params , source ) in enumerate (self .restrictions ):
281
290
if index in registered_restrictions :
282
291
continue
283
292
if all (p in registered_params for p in params ):
284
- res_dict [param ] = res
293
+ if param in res_dict :
294
+ raise KeyError (f"`{ param } ` is already in res_dict with `{ res_dict [param ][1 ]} `, can't add `{ source } `" )
295
+ res_dict [param ] = (res , source )
296
+ print (source , res , param , params )
285
297
registered_restrictions .append (index )
286
298
287
299
# define the Tunable Parameters
@@ -291,12 +303,15 @@ def get_params():
291
303
vi = get_interval (values )
292
304
vals = Interval (vi [0 ], vi [1 ], vi [2 ]) if vi is not None and vi [2 ] != 0 else Set (* np .array (values ).flatten ())
293
305
constraint = res_dict .get (key , None )
306
+ constraint_source = None
307
+ if constraint is not None :
308
+ constraint , constraint_source = constraint
294
309
# in case of a leftover monolithic restriction, append at the last parameter
295
310
if index == len (self .tune_params ) - 1 and len (res_dict ) == 0 and len (self .restrictions ) == 1 :
296
- res = self .restrictions [ 0 ] [0 ]
311
+ res , params , source = self .restrictions [0 ]
297
312
assert callable (res )
298
313
constraint = res
299
- params .append (TP (key , vals , constraint ))
314
+ params .append (TP (key , vals , constraint , constraint_source ))
300
315
return params
301
316
302
317
# tune
@@ -369,7 +384,7 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
369
384
):
370
385
self ._modified_restrictions .append (max_block_size_product )
371
386
if isinstance (self .restrictions , list ):
372
- self .restrictions .append ((MaxProdConstraint (max_threads ), valid_block_size_names ))
387
+ self .restrictions .append ((MaxProdConstraint (max_threads ), valid_block_size_names , None ))
373
388
374
389
# construct the parameter space with the constraints applied
375
390
return parameter_space .getSolutionsAsListDict (order = self .param_names )
@@ -382,7 +397,7 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
382
397
383
398
# convert to a Constraint type if necessary
384
399
if isinstance (restriction , tuple ):
385
- restriction , required_params = restriction
400
+ restriction , required_params , _ = restriction
386
401
if callable (restriction ) and not isinstance (restriction , Constraint ):
387
402
restriction = FunctionConstraint (restriction )
388
403
0 commit comments