4
4
5
5
import numpy as np
6
6
7
- from kernel_tuner .util import StopCriterionReached
7
+ from kernel_tuner .util import StopCriterionReached , ErrorConfig
8
8
from kernel_tuner .searchspace import Searchspace
9
9
from kernel_tuner .strategies import common
10
10
from kernel_tuner .strategies .common import CostFunc
18
18
19
19
def tune (searchspace : Searchspace , runner , tuning_options ):
20
20
# SA works with real parameter values and does not need scaling
21
- cost_func = CostFunc (searchspace , tuning_options , runner )
21
+ cost_func = CostFunc (searchspace , tuning_options , runner , return_invalid = True )
22
22
23
23
# optimization parameters
24
24
T , T_min , alpha , niter , constraint_aware = common .get_options (tuning_options .strategy_options , _options )
@@ -36,7 +36,7 @@ def tune(searchspace: Searchspace, runner, tuning_options):
36
36
37
37
# get random starting point and evaluate cost
38
38
pos = generate_starting_point (searchspace , constraint_aware )
39
- old_cost = cost_func (pos , check_restrictions = False )
39
+ old_cost = cost_func (pos , check_restrictions = not constraint_aware )
40
40
41
41
# main optimization loop
42
42
stuck = 0
@@ -92,13 +92,12 @@ def tune(searchspace: Searchspace, runner, tuning_options):
92
92
93
93
def acceptance_prob (old_cost , new_cost , T , tuning_options ):
94
94
"""Annealing equation, with modifications to work towards a lower value."""
95
- error_val = sys .float_info .max
96
95
res = 0.0
97
96
# if start pos is not valid, always move
98
- if old_cost == error_val :
97
+ if isinstance ( old_cost , ErrorConfig ) :
99
98
res = 1.0
100
99
# if we have found a valid ps before, never move to nonvalid pos
101
- elif new_cost == error_val :
100
+ elif isinstance ( new_cost , ErrorConfig ) :
102
101
res = 0.0
103
102
# always move if new cost is better
104
103
elif new_cost < old_cost :
@@ -117,7 +116,7 @@ def neighbor(pos, searchspace: Searchspace, constraint_aware=True):
117
116
118
117
def random_neighbor (pos , method ):
119
118
"""Helper method to return a random neighbor."""
120
- neighbors = searchspace .get_neighbor (pos , neighbor_method = method )
119
+ neighbors = searchspace .get_neighbors (pos , neighbor_method = method )
121
120
if not neighbors :
122
121
return pos
123
122
return random .choice (neighbors )
0 commit comments