Skip to content

Commit 27226a9

Browse files
committed
Improvements to Simulated Annealing regarding invalid and error configurations
1 parent f2902b5 commit 27226a9

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

kernel_tuner/strategies/simulated_annealing.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from kernel_tuner.util import StopCriterionReached
7+
from kernel_tuner.util import StopCriterionReached, ErrorConfig
88
from kernel_tuner.searchspace import Searchspace
99
from kernel_tuner.strategies import common
1010
from kernel_tuner.strategies.common import CostFunc
@@ -18,7 +18,7 @@
1818

1919
def tune(searchspace: Searchspace, runner, tuning_options):
2020
# SA works with real parameter values and does not need scaling
21-
cost_func = CostFunc(searchspace, tuning_options, runner)
21+
cost_func = CostFunc(searchspace, tuning_options, runner, return_invalid=True)
2222

2323
# optimization parameters
2424
T, T_min, alpha, niter, constraint_aware = common.get_options(tuning_options.strategy_options, _options)
@@ -36,7 +36,7 @@ def tune(searchspace: Searchspace, runner, tuning_options):
3636

3737
# get random starting point and evaluate cost
3838
pos = generate_starting_point(searchspace, constraint_aware)
39-
old_cost = cost_func(pos, check_restrictions=False)
39+
old_cost = cost_func(pos, check_restrictions=not constraint_aware)
4040

4141
# main optimization loop
4242
stuck = 0
@@ -92,13 +92,12 @@ def tune(searchspace: Searchspace, runner, tuning_options):
9292

9393
def acceptance_prob(old_cost, new_cost, T, tuning_options):
9494
"""Annealing equation, with modifications to work towards a lower value."""
95-
error_val = sys.float_info.max
9695
res = 0.0
9796
# if start pos is not valid, always move
98-
if old_cost == error_val:
97+
if isinstance(old_cost, ErrorConfig):
9998
res = 1.0
10099
# if we have found a valid ps before, never move to nonvalid pos
101-
elif new_cost == error_val:
100+
elif isinstance(new_cost, ErrorConfig):
102101
res = 0.0
103102
# always move if new cost is better
104103
elif new_cost < old_cost:
@@ -117,7 +116,7 @@ def neighbor(pos, searchspace: Searchspace, constraint_aware=True):
117116

118117
def random_neighbor(pos, method):
119118
"""Helper method to return a random neighbor."""
120-
neighbors = searchspace.get_neighbor(pos, neighbor_method=method)
119+
neighbors = searchspace.get_neighbors(pos, neighbor_method=method)
121120
if not neighbors:
122121
return pos
123122
return random.choice(neighbors)

0 commit comments

Comments
 (0)