
Commit 376f205

Added observer pattern to allow callbacks
1 parent 317f735 commit 376f205

6 files changed (+155 −72 lines)

bayes_opt/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,5 @@
-from .bayesian_optimization import BayesianOptimization
+from .bayesian_optimization import BayesianOptimization, Events
 from .helpers import UtilityFunction
+from .observer import Observer
 
-__all__ = ["BayesianOptimization", "UtilityFunction"]
+__all__ = ["BayesianOptimization", "UtilityFunction", "Events", "Observer"]

bayes_opt/bayesian_optimization.py

Lines changed: 21 additions & 1 deletion
@@ -7,9 +7,10 @@
 from sklearn.gaussian_process.kernels import Matern
 from .helpers import (UtilityFunction, PrintLog, acq_max, ensure_rng)
 from .target_space import TargetSpace
+from .observer import Observable
 
 
-class BayesianOptimization(object):
+class BayesianOptimization(Observable):
 
     def __init__(self, f, pbounds, random_state=None, verbose=1):
         """
@@ -71,6 +72,10 @@ def __init__(self, f, pbounds, random_state=None, verbose=1):
         # Verbose
         self.verbose = verbose
 
+        # Event initialization
+        events = [Events.INIT_DONE, Events.FIT_STEP_DONE, Events.FIT_DONE]
+        super(BayesianOptimization, self).__init__(events)
+
     def init(self, init_points):
         """
         Initialization method to kick start the optimization process. It is a
@@ -100,6 +105,9 @@ def init(self, init_points):
         # Updates the flag
         self.initialized = True
 
+        # Notify about finished init method
+        self.dispatch(Events.INIT_DONE)
+
     def _observe_point(self, x):
         y = self.space.observe_point(x)
         if self.verbose:
@@ -303,10 +311,16 @@ def maximize(self,
             # Keep track of total number of iterations
             self.i += 1
 
+            # Notify about finished iteration
+            self.dispatch(Events.FIT_STEP_DONE)
+
         # Print a final report if verbose active.
         if self.verbose:
             self.plog.print_summary()
 
+        # Notify about finished optimization
+        self.dispatch(Events.FIT_DONE)
+
     def points_to_csv(self, file_name):
         """
         After training all points for which we know target variable
@@ -353,3 +367,9 @@ def bounds(self):
     def dim(self):
         warnings.warn("use self.space.dim instead", DeprecationWarning)
         return self.space.dim
+
+
+class Events(object):
+    INIT_DONE = 'initialized'
+    FIT_STEP_DONE = 'fit_step_done'
+    FIT_DONE = 'fit_done'
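
For reference, a minimal sketch of how a caller could hook into the new events once this commit is applied. It assumes the names exported by this commit (BayesianOptimization, Events) and the pre-existing constructor and maximize() signatures; the black_box objective, the pbounds values, the 'step_logger' key, and the maximize() arguments are made up for illustration.

from bayes_opt import BayesianOptimization, Events

# Hypothetical objective; the keyword names must match the pbounds keys.
def black_box(x, y):
    return -x ** 2 - (y - 1) ** 2 + 1

optimizer = BayesianOptimization(
    f=black_box,
    pbounds={'x': (-2, 2), 'y': (-3, 3)},
    verbose=0,
)

# Subscribe a plain callback to the per-iteration event. dispatch() invokes
# callback(event, instance), so the callback can read instance.res['max'].
optimizer.register(
    Events.FIT_STEP_DONE,
    who='step_logger',
    callback=lambda event, instance: print("step done:", instance.res['max']),
)

optimizer.maximize(init_points=2, n_iter=5)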

bayes_opt/observer.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+# Inspired/Taken from https://www.protechtraining.com/blog/post/879#simple-observer
+
+
+class Observer:
+    def update(self, event, instance):
+        # Avoid circular import
+        from .bayesian_optimization import Events
+        if event is Events.INIT_DONE:
+            print("Initialization completed")
+        elif event is Events.FIT_STEP_DONE:
+            print("Optimization step finished, current max: ", instance.res['max'])
+        elif event is Events.FIT_DONE:
+            print("Optimization finished, maximum value at: ", instance.res['max'])
+
+
+class Observable(object):
+    def __init__(self, events):
+        # maps event names to subscribers
+        # str -> dict
+        self.events = {event: dict()
+                       for event in events}
+
+    def get_subscribers(self, event):
+        return self.events[event]
+
+    def register(self, event, who, callback=None):
+        if callback == None:
+            callback = getattr(who, 'update')
+        self.get_subscribers(event)[who] = callback
+
+    def unregister(self, event, who):
+        del self.get_subscribers(event)[who]
+
+    def dispatch(self, event):
+        for subscriber, callback in self.get_subscribers(event).items():
+            callback(event, self)
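
A short sketch of the Observer/Observable pair added here used on its own, with a custom observer that records the events it receives. RecordingObserver and the event names 'started'/'finished' are arbitrary choices for the example; everything else follows the observer.py code above.

from bayes_opt.observer import Observable

class RecordingObserver:
    """Any object with an update(event, instance) method can subscribe."""
    def __init__(self):
        self.seen = []

    def update(self, event, instance):
        self.seen.append(event)

observable = Observable(events=['started', 'finished'])
recorder = RecordingObserver()

# With no explicit callback, register() falls back to the observer's update().
observable.register('started', recorder)
observable.dispatch('started')
observable.dispatch('started')

print(recorder.seen)  # ['started', 'started']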

examples/exploitation vs exploration.ipynb

Lines changed: 46 additions & 68 deletions
Large diffs are not rendered by default.

tests/test_helper_functions.py

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@
 from sklearn.gaussian_process.kernels import Matern
 from bayes_opt.helpers import UtilityFunction, acq_max, ensure_rng
 
-
 def get_globals():
     X = np.array([
         [0.00, 0.00],

tests/test_observer.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import unittest
+from bayes_opt.observer import Observable
+
+
+class TestObserver():
+    def __init__(self):
+        self.counter = 0
+
+    def update(self, event, instance):
+        self.counter += 1
+
+class TestObserverPattern(unittest.TestCase):
+    def setUp(self):
+        events = ['a', 'b']
+        self.observable = Observable(events)
+        self.observer = TestObserver()
+
+    def test_get_subscribers(self):
+        self.observable.register('a', self.observer)
+        self.assertTrue(self.observer in self.observable.get_subscribers('a'))
+        self.assertTrue(len(self.observable.get_subscribers('a').keys()) == 1)
+        self.assertTrue(len(self.observable.get_subscribers('b').keys()) == 0)
+
+    def test_register(self):
+        self.observable.register('a', self.observer)
+        self.assertTrue(self.observer in self.observable.get_subscribers('a'))
+
+    def test_unregister(self):
+        self.observable.register('a', self.observer)
+        self.observable.unregister('a', self.observer)
+        self.assertTrue(self.observer not in self.observable.get_subscribers('a'))
+
+    def test_dispatch(self):
+        test_observer = TestObserver()
+        self.observable.register('b', test_observer)
+        self.observable.dispatch('b')
+        self.observable.dispatch('b')
+
+        self.assertTrue(test_observer.counter == 2)
+
+
+if __name__ == '__main__':
+    r"""
+    CommandLine:
+        python tests/test_observer.py
+    """
+    # unittest.main()
+    import pytest
+    pytest.main([__file__])
