1818)
1919
2020from kernel_tuner .util import check_restrictions as check_instance_restrictions
21- from kernel_tuner .util import compile_restrictions , default_block_size_names
21+ from kernel_tuner .util import compile_restrictions , default_block_size_names , get_interval
2222
2323supported_neighbor_methods = ["strictly-adjacent" , "adjacent" , "Hamming" ]
2424
@@ -262,7 +262,7 @@ def all_smt(formula, keys) -> list:
262262
263263 def __build_searchspace_pyATF (self , block_size_names : list , max_threads : int , solver : Solver ):
264264 """Builds the searchspace using pyATF."""
265- from pyatf import TP , Set , Tuner
265+ from pyatf import TP , Set , Interval , Tuner
266266 from pyatf .cost_functions .generic import CostFunction
267267 from pyatf .search_techniques import Exhaustive
268268
@@ -289,7 +289,8 @@ def get_params():
289289 params = list ()
290290 print ("get_params" )
291291 for index , (key , values ) in enumerate (self .tune_params .items ()):
292- vals = Set (* np .array (values ).flatten ()) # TODO check if can be interval
292+ vi = get_interval (values )
293+ vals = Interval (vi [0 ], vi [1 ], vi [2 ]) if vi is not None else Set (* np .array (values ).flatten ())
293294 constraint = res_dict .get (key , None )
294295 if len (res_dict ) == 0 and index == len (self .tune_params ) - 1 and constraint is None :
295296 res = self .restrictions [0 ][0 ]
0 commit comments