2
2
3
3
import numpy as np
4
4
import pytest
5
+ from pathlib import Path
5
6
6
7
import kernel_tuner
7
8
from kernel_tuner .util import InvalidConfig
8
9
from kernel_tuner .interface import strategy_map
9
10
10
11
from ..context import skip_if_no_bayesopt_botorch , skip_if_no_bayesopt_gpytorch
11
12
12
- cache_filename = os .path .dirname (os .path .realpath (__file__ )) + "/test_cache_file.json"
13
13
14
14
@pytest .fixture
15
15
def vector_add ():
@@ -51,7 +51,7 @@ def vector_add():
51
51
strategies .append (s )
52
52
@pytest .mark .parametrize ('strategy' , strategies )
53
53
def test_strategies (vector_add , strategy ):
54
-
54
+ cache_filename = Path ( __file__ ). parent / "test_cache_file.json"
55
55
options = dict (popsize = 5 , neighbor = 'adjacent' )
56
56
57
57
print (f"testing { strategy } " )
@@ -64,6 +64,17 @@ def test_strategies(vector_add, strategy):
64
64
65
65
restrictions = ["test_string == 'alg_2'" , "test_bool == True" , "test_mixed == 2.45" ]
66
66
67
+ # pyATF can't handle non-number tune parameters, so we filter them out
68
+ if strategy == "pyatf_strategies" :
69
+ tune_params = {
70
+ "block_size_x" : [128 + 64 * i for i in range (15 )]
71
+ }
72
+ restrictions = []
73
+ cache_filename = cache_filename .parent .parent / "test_cache_file.json"
74
+ vector_add [- 1 ] = tune_params
75
+
76
+ # run the tuning in simulation mode
77
+ assert cache_filename .exists ()
67
78
results , _ = kernel_tuner .tune_kernel (* vector_add , restrictions = restrictions , strategy = strategy , strategy_options = filter_options ,
68
79
verbose = False , cache = cache_filename , simulation_mode = True )
69
80
@@ -82,10 +93,6 @@ def test_strategies(vector_add, strategy):
82
93
# check whether the returned dictionaries contain exactly the expected keys and the appropriate type
83
94
expected_items = {
84
95
'block_size_x' : int ,
85
- 'test_string' : str ,
86
- 'test_single' : int ,
87
- 'test_bool' : bool ,
88
- 'test_mixed' : float ,
89
96
'time' : (float , int ),
90
97
'times' : list ,
91
98
'compile_time' : (float , int ),
@@ -95,6 +102,11 @@ def test_strategies(vector_add, strategy):
95
102
'framework_time' : (float , int ),
96
103
'timestamp' : str
97
104
}
105
+ if strategy != "pyatf_strategies" :
106
+ expected_items ['test_string' ] = str
107
+ expected_items ['test_single' ] = int
108
+ expected_items ['test_bool' ] = bool
109
+ expected_items ['test_mixed' ] = float
98
110
for res in results :
99
111
assert len (res ) == len (expected_items )
100
112
for expected_key , expected_type in expected_items .items ():
0 commit comments