18
18
)
19
19
20
20
from kernel_tuner .util import check_restrictions as check_instance_restrictions
21
- from kernel_tuner .util import compile_restrictions , default_block_size_names
21
+ from kernel_tuner .util import compile_restrictions , default_block_size_names , get_interval
22
22
23
23
supported_neighbor_methods = ["strictly-adjacent" , "adjacent" , "Hamming" ]
24
24
@@ -262,7 +262,7 @@ def all_smt(formula, keys) -> list:
262
262
263
263
def __build_searchspace_pyATF (self , block_size_names : list , max_threads : int , solver : Solver ):
264
264
"""Builds the searchspace using pyATF."""
265
- from pyatf import TP , Set , Tuner
265
+ from pyatf import TP , Set , Interval , Tuner
266
266
from pyatf .cost_functions .generic import CostFunction
267
267
from pyatf .search_techniques import Exhaustive
268
268
@@ -289,7 +289,8 @@ def get_params():
289
289
params = list ()
290
290
print ("get_params" )
291
291
for index , (key , values ) in enumerate (self .tune_params .items ()):
292
- vals = Set (* np .array (values ).flatten ()) # TODO check if can be interval
292
+ vi = get_interval (values )
293
+ vals = Interval (vi [0 ], vi [1 ], vi [2 ]) if vi is not None else Set (* np .array (values ).flatten ())
293
294
constraint = res_dict .get (key , None )
294
295
if len (res_dict ) == 0 and index == len (self .tune_params ) - 1 and constraint is None :
295
296
res = self .restrictions [0 ][0 ]
0 commit comments