Skip to content

Commit 235de0d

Browse files
ben-arnaofmfn
authored andcommitted
add kappa decay
Update util.py Update util.py
1 parent af48b03 commit 235de0d

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

bayes_opt/bayesian_optimization.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ def maximize(self,
154154
n_iter=25,
155155
acq='ucb',
156156
kappa=2.576,
157+
kappa_decay=1,
158+
kappa_decay_delay=0,
157159
xi=0.0,
158160
**gp_params):
159161
"""Mazimize your function"""
@@ -162,12 +164,17 @@ def maximize(self,
162164
self._prime_queue(init_points)
163165
self.set_gp_params(**gp_params)
164166

165-
util = UtilityFunction(kind=acq, kappa=kappa, xi=xi)
167+
util = UtilityFunction(kind=acq,
168+
kappa=kappa,
169+
xi=xi,
170+
kappa_decay=kappa_decay,
171+
kappa_decay_delay=kappa_decay_delay)
166172
iteration = 0
167173
while not self._queue.empty or iteration < n_iter:
168174
try:
169175
x_probe = next(self._queue)
170176
except StopIteration:
177+
util.update_params()
171178
x_probe = self.suggest(util)
172179
iteration += 1
173180

bayes_opt/util.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,15 @@ class UtilityFunction(object):
7676
An object to compute the acquisition functions.
7777
"""
7878

79-
def __init__(self, kind, kappa, xi):
80-
"""
81-
If UCB is to be used, a constant kappa is needed.
82-
"""
79+
def __init__(self, kind, kappa, xi, kappa_decay=1, kappa_decay_delay=0):
80+
8381
self.kappa = kappa
82+
self._kappa_decay = kappa_decay
83+
self._kappa_decay_delay = kappa_decay_delay
8484

8585
self.xi = xi
86+
87+
self._iters_counter = 0
8688

8789
if kind not in ['ucb', 'ei', 'poi']:
8890
err = "The utility function " \
@@ -92,6 +94,12 @@ def __init__(self, kind, kappa, xi):
9294
else:
9395
self.kind = kind
9496

97+
def update_params(self):
98+
self._iters_counter += 1
99+
100+
if self._kappa_decay < 1 and self._iters_counter > self._kappa_decay_delay:
101+
self.kappa *= self._kappa_decay
102+
95103
def utility(self, x, gp, y_max):
96104
if self.kind == 'ucb':
97105
return self._ucb(x, gp, self.kappa)

0 commit comments

Comments
 (0)