Skip to content

Commit 57d853a

Browse files
Merge pull request #315 from KernelTuner/fix-issue-314
Change `CostFunc` to return `+inf` when `objective_higher_is_better`
2 parents 552973d + 68b9809 commit 57d853a

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

kernel_tuner/strategies/common.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from time import perf_counter
44

55
import numpy as np
6+
import numbers
67

78
from kernel_tuner import util
89
from kernel_tuner.searchspace import Searchspace
@@ -109,8 +110,15 @@ def __call__(self, x, check_restrictions=True):
109110
self.runner.last_strategy_start_time = perf_counter()
110111

111112
# get numerical return value, taking optimization direction into account
112-
return_value = result[self.tuning_options.objective] or sys.float_info.max
113-
return_value = return_value if not self.tuning_options.objective_higher_is_better else -return_value
113+
return_value = result[self.tuning_options.objective]
114+
115+
if isinstance(return_value, numbers.Number):
116+
if self.tuning_options.objective_higher_is_better:
117+
# flip the sign if higher means better
118+
return_value = -return_value
119+
else:
120+
# this is not a valid configuration, just return max
121+
return_value = sys.float_info.max
114122

115123
return return_value
116124

kernel_tuner/strategies/simulated_annealing.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,24 @@ def tune(searchspace: Searchspace, runner, tuning_options):
8787

8888
def acceptance_prob(old_cost, new_cost, T, tuning_options):
8989
"""Annealing equation, with modifications to work towards a lower value."""
90-
error_val = sys.float_info.max if not tuning_options.objective_higher_is_better else -sys.float_info.max
90+
error_val = sys.float_info.max
91+
res = 0.0
9192
# if start pos is not valid, always move
9293
if old_cost == error_val:
93-
return 1.0
94+
res = 1.0
9495
# if we have found a valid ps before, never move to nonvalid pos
95-
if new_cost == error_val:
96-
return 0.0
96+
elif new_cost == error_val:
97+
res = 0.0
9798
# always move if new cost is better
98-
if new_cost < old_cost:
99-
return 1.0
99+
elif new_cost < old_cost:
100+
res = 1.0
100101
# maybe move if old cost is better than new cost depending on T and random value
101-
if tuning_options.objective_higher_is_better:
102-
return np.exp(((new_cost-old_cost)/new_cost)/T)
103-
return np.exp(((old_cost-new_cost)/old_cost)/T)
102+
else:
103+
if tuning_options.objective_higher_is_better:
104+
res = np.exp(((new_cost-old_cost)/new_cost)/T)
105+
else:
106+
res = np.exp(((old_cost-new_cost)/old_cost)/T)
107+
return res
104108

105109

106110
def neighbor(pos, searchspace: Searchspace):

0 commit comments

Comments
 (0)