Skip to content

Commit f2a0f03

Browse files
committed
Generating intervals for pyATF where possible
1 parent 31ad0c8 commit f2a0f03

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

kernel_tuner/searchspace.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919

2020
from kernel_tuner.util import check_restrictions as check_instance_restrictions
21-
from kernel_tuner.util import compile_restrictions, default_block_size_names
21+
from kernel_tuner.util import compile_restrictions, default_block_size_names, get_interval
2222

2323
supported_neighbor_methods = ["strictly-adjacent", "adjacent", "Hamming"]
2424

@@ -262,7 +262,7 @@ def all_smt(formula, keys) -> list:
262262

263263
def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, solver: Solver):
264264
"""Builds the searchspace using pyATF."""
265-
from pyatf import TP, Set, Tuner
265+
from pyatf import TP, Set, Interval, Tuner
266266
from pyatf.cost_functions.generic import CostFunction
267267
from pyatf.search_techniques import Exhaustive
268268

@@ -289,7 +289,8 @@ def get_params():
289289
params = list()
290290
print("get_params")
291291
for index, (key, values) in enumerate(self.tune_params.items()):
292-
vals = Set(*np.array(values).flatten()) # TODO check if can be interval
292+
vi = get_interval(values)
293+
vals = Interval(vi[0], vi[1], vi[2]) if vi is not None else Set(*np.array(values).flatten())
293294
constraint = res_dict.get(key, None)
294295
if len(res_dict) == 0 and index == len(self.tune_params) - 1 and constraint is None:
295296
res = self.restrictions[0][0]

kernel_tuner/util.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,23 @@ def get_instance_string(params):
439439
return "_".join([str(i) for i in params.values()])
440440

441441

442+
def get_interval(a: list):
443+
"""Checks if an array can be an interval. Returns (start, end, step) if interval, otherwise None."""
444+
if not all(isinstance(e, (int, float)) for e in a):
445+
return None
446+
a_min = min(a)
447+
a_max = max(a)
448+
if len(a) <= 2:
449+
return (a_min, a_max, a_max-a_min)
450+
# determine the first step size
451+
step = a[1]-a_min
452+
# for each element, the step size should be equal to the first step
453+
for i, e in enumerate(a):
454+
if e-a[i-1] != step:
455+
return None
456+
return (a_min, a_max, step)
457+
458+
442459
def get_kernel_string(kernel_source, params=None):
443460
"""Retrieve the kernel source and return as a string.
444461

0 commit comments

Comments
 (0)