@@ -269,19 +269,31 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so
269269 # Define a bogus cost function
270270 costfunc = CostFunction (":" ) # bash no-op
271271
272- # build a dictionary of the restrictions, combined based on last parameter
273- print (self .restrictions )
272+ # add the Kernel Tuner default blocksize threads restrictions
274273 assert isinstance (self .restrictions , list )
274+ valid_block_size_names = list (
275+ block_size_name for block_size_name in block_size_names if block_size_name in self .param_names
276+ )
277+ if len (valid_block_size_names ) > 0 :
278+ # adding the default blocksize restriction requires recompilation because pyATF requires combined restrictions for the same parameter
279+ max_block_size_product = f"{ ' * ' .join (valid_block_size_names )} <= { max_threads } "
280+ restrictions = self ._modified_restrictions .copy () + [max_block_size_product ]
281+ self .restrictions = compile_restrictions (restrictions , self .tune_params , format = "pyatf" , try_to_constraint = False )
282+
283+ # build a dictionary of the restrictions, combined based on last parameter
275284 res_dict = dict ()
276285 registered_params = list ()
277286 registered_restrictions = list ()
278287 for param in self .tune_params .keys ():
279288 registered_params .append (param )
280- for index , (res , params ) in enumerate (self .restrictions ):
289+ for index , (res , params , source ) in enumerate (self .restrictions ):
281290 if index in registered_restrictions :
282291 continue
283292 if all (p in registered_params for p in params ):
284- res_dict [param ] = res
293+ if param in res_dict :
294+ raise KeyError (f"`{ param } ` is already in res_dict with `{ res_dict [param ][1 ]} `, can't add `{ source } `" )
295+ res_dict [param ] = (res , source )
296+ print (source , res , param , params )
285297 registered_restrictions .append (index )
286298
287299 # define the Tunable Parameters
@@ -291,12 +303,15 @@ def get_params():
291303 vi = get_interval (values )
292304 vals = Interval (vi [0 ], vi [1 ], vi [2 ]) if vi is not None and vi [2 ] != 0 else Set (* np .array (values ).flatten ())
293305 constraint = res_dict .get (key , None )
306+ constraint_source = None
307+ if constraint is not None :
308+ constraint , constraint_source = constraint
294309 # in case of a leftover monolithic restriction, append at the last parameter
295310 if index == len (self .tune_params ) - 1 and len (res_dict ) == 0 and len (self .restrictions ) == 1 :
296- res = self .restrictions [ 0 ] [0 ]
311+ res , params , source = self .restrictions [0 ]
297312 assert callable (res )
298313 constraint = res
299- params .append (TP (key , vals , constraint ))
314+ params .append (TP (key , vals , constraint , constraint_source ))
300315 return params
301316
302317 # tune
@@ -369,7 +384,7 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
369384 ):
370385 self ._modified_restrictions .append (max_block_size_product )
371386 if isinstance (self .restrictions , list ):
372- self .restrictions .append ((MaxProdConstraint (max_threads ), valid_block_size_names ))
387+ self .restrictions .append ((MaxProdConstraint (max_threads ), valid_block_size_names , None ))
373388
374389 # construct the parameter space with the constraints applied
375390 return parameter_space .getSolutionsAsListDict (order = self .param_names )
@@ -382,7 +397,7 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
382397
383398 # convert to a Constraint type if necessary
384399 if isinstance (restriction , tuple ):
385- restriction , required_params = restriction
400+ restriction , required_params , _ = restriction
386401 if callable (restriction ) and not isinstance (restriction , Constraint ):
387402 restriction = FunctionConstraint (restriction )
388403
0 commit comments