Skip to content

Commit 4394d13

Browse files
committed
Amended strategy tests to account for pyATF limitations
1 parent 3d12130 commit 4394d13

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

test/strategies/test_strategies.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
import numpy as np
44
import pytest
5+
from pathlib import Path
56

67
import kernel_tuner
78
from kernel_tuner.util import InvalidConfig
89
from kernel_tuner.interface import strategy_map
910

1011
from ..context import skip_if_no_bayesopt_botorch, skip_if_no_bayesopt_gpytorch
1112

12-
cache_filename = os.path.dirname(os.path.realpath(__file__)) + "/test_cache_file.json"
1313

1414
@pytest.fixture
1515
def vector_add():
@@ -51,7 +51,7 @@ def vector_add():
5151
strategies.append(s)
5252
@pytest.mark.parametrize('strategy', strategies)
5353
def test_strategies(vector_add, strategy):
54-
54+
cache_filename = Path(__file__).parent / "test_cache_file.json"
5555
options = dict(popsize=5, neighbor='adjacent')
5656

5757
print(f"testing {strategy}")
@@ -64,6 +64,17 @@ def test_strategies(vector_add, strategy):
6464

6565
restrictions = ["test_string == 'alg_2'", "test_bool == True", "test_mixed == 2.45"]
6666

67+
# pyATF can't handle non-number tune parameters, so we filter them out
68+
if strategy == "pyatf_strategies":
69+
tune_params = {
70+
"block_size_x": [128 + 64 * i for i in range(15)]
71+
}
72+
restrictions = []
73+
cache_filename = cache_filename.parent.parent / "test_cache_file.json"
74+
vector_add[-1] = tune_params
75+
76+
# run the tuning in simulation mode
77+
assert cache_filename.exists()
6778
results, _ = kernel_tuner.tune_kernel(*vector_add, restrictions=restrictions, strategy=strategy, strategy_options=filter_options,
6879
verbose=False, cache=cache_filename, simulation_mode=True)
6980

@@ -82,10 +93,6 @@ def test_strategies(vector_add, strategy):
8293
# check whether the returned dictionaries contain exactly the expected keys and the appropriate type
8394
expected_items = {
8495
'block_size_x': int,
85-
'test_string': str,
86-
'test_single': int,
87-
'test_bool': bool,
88-
'test_mixed': float,
8996
'time': (float, int),
9097
'times': list,
9198
'compile_time': (float, int),
@@ -95,6 +102,11 @@ def test_strategies(vector_add, strategy):
95102
'framework_time': (float, int),
96103
'timestamp': str
97104
}
105+
if strategy != "pyatf_strategies":
106+
expected_items['test_string'] = str
107+
expected_items['test_single'] = int
108+
expected_items['test_bool'] = bool
109+
expected_items['test_mixed'] = float
98110
for res in results:
99111
assert len(res) == len(expected_items)
100112
for expected_key, expected_type in expected_items.items():

0 commit comments

Comments
 (0)