Skip to content

Commit c7accf5

Browse files
fmfnfmfn
authored andcommitted
Introduce default printer and json observers
This commit refactors heavily how loggin is done in the package. It removes the old style of loggin and does everything now using observers. By default the user has access to two useful observers, json and printer, and if the users doesn't want to deal with that the verbose parameter takes care of the basics.
1 parent 58245b0 commit c7accf5

File tree

8 files changed

+286
-151
lines changed

8 files changed

+286
-151
lines changed

bayes_opt/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .bayesian_optimization import BayesianOptimization, Events
2-
from .helpers import UtilityFunction
3-
from .observer import Observer
2+
from .util import UtilityFunction
3+
from .observer import ScreenLogger
44

5-
__all__ = ["BayesianOptimization", "UtilityFunction", "Events", "Observer"]
5+
__all__ = ["BayesianOptimization", "UtilityFunction", "Events", "ScreenLogger"]

bayes_opt/bayesian_optimization.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import numpy as np
33

44
from .target_space import TargetSpace
5-
from .observer import Observable, Events
6-
from .helpers import UtilityFunction, acq_max, ensure_rng
5+
from .event import Events, DEFAULT_EVENTS
6+
from .observer import _get_default_logger
7+
from .util import UtilityFunction, acq_max, ensure_rng
78

89
from sklearn.gaussian_process.kernels import Matern
910
from sklearn.gaussian_process import GaussianProcessRegressor
@@ -22,7 +23,7 @@ def __len__(self):
2223

2324
def __next__(self):
2425
if self.empty:
25-
raise ValueError("Cannot retrieve next object from empty queue.")
26+
raise StopIteration("Queue is empty, no more objects to retrieve.")
2627
obj = self._queue[0]
2728
self._queue = self._queue[1:]
2829
return obj
@@ -32,6 +33,33 @@ def add(self, obj):
3233
self._queue.append(obj)
3334

3435

36+
class Observable:
37+
"""
38+
39+
Inspired/Taken from
40+
https://www.protechtraining.com/blog/post/879#simple-observer
41+
"""
42+
def __init__(self, events):
43+
# maps event names to subscribers
44+
# str -> dict
45+
self._events = {event: dict() for event in events}
46+
47+
def get_subscribers(self, event):
48+
return self._events[event]
49+
50+
def subscribe(self, event, subscriber, callback=None):
51+
if callback == None:
52+
callback = getattr(subscriber, 'update')
53+
self.get_subscribers(event)[subscriber] = callback
54+
55+
def unsubscribe(self, event, subscriber):
56+
del self.get_subscribers(event)[subscriber]
57+
58+
def dispatch(self, event):
59+
for _, callback in self.get_subscribers(event).items():
60+
callback(event, self)
61+
62+
3563
class BayesianOptimization(Observable):
3664
def __init__(self, f, pbounds, random_state=None, verbose=1):
3765
""""""
@@ -53,7 +81,8 @@ def __init__(self, f, pbounds, random_state=None, verbose=1):
5381
random_state=self._random_state,
5482
)
5583

56-
super(BayesianOptimization, self).__init__(events=None)
84+
self._verbose = verbose
85+
super(BayesianOptimization, self).__init__(events=DEFAULT_EVENTS)
5786

5887
@property
5988
def space(self):
@@ -80,6 +109,9 @@ def probe(self, x, lazy=True):
80109

81110
def suggest(self, utility_function):
82111
"""Most promissing point to probe next"""
112+
if len(self._space) == 0:
113+
return self._space.array_to_params(self._space.random_sample())
114+
83115
# Sklearn's GP throws a large number of warnings at times, but
84116
# we don't really need to see them here.
85117
with warnings.catch_warnings():
@@ -105,6 +137,13 @@ def _prime_queue(self, init_points):
105137
for _ in range(init_points):
106138
self._queue.add(self._space.random_sample())
107139

140+
def _prime_subscriptions(self):
141+
if not any([len(subs) for subs in self._events.values()]):
142+
_logger = _get_default_logger(self._verbose)
143+
self.subscribe(Events.OPTMIZATION_START, _logger)
144+
self.subscribe(Events.OPTMIZATION_STEP, _logger)
145+
self.subscribe(Events.OPTMIZATION_END, _logger)
146+
108147
def maximize(self,
109148
init_points: int=5,
110149
n_iter: int=25,
@@ -113,21 +152,24 @@ def maximize(self,
113152
xi: float=0.0,
114153
**gp_params):
115154
"""Mazimize your function"""
155+
self._prime_subscriptions()
156+
self.dispatch(Events.OPTMIZATION_START)
116157
self._prime_queue(init_points)
117158

118159
util = UtilityFunction(kind=acq, kappa=kappa, xi=xi)
119160
iteration = 0
120161
while not self._queue.empty or iteration < n_iter:
121162
try:
122163
x_probe = next(self._queue)
123-
except ValueError:
164+
except StopIteration:
124165
x_probe = self.suggest(util)
125166
iteration += 1
126167

127168
self.probe(x_probe, lazy=False)
169+
self.dispatch(Events.OPTMIZATION_STEP)
128170

129171
# Notify about finished optimization
130-
self.dispatch(Events.FIT_DONE)
172+
self.dispatch(Events.OPTMIZATION_END)
131173

132174
def set_bounds(self, new_bounds):
133175
"""

bayes_opt/event.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
class Events:
2+
OPTMIZATION_START = 'optmization:start'
3+
OPTMIZATION_STEP = 'optmization:step'
4+
OPTMIZATION_END = 'optmization:end'
5+
6+
PROBE_FROM_SUGGESTION = "probe:suggestion"
7+
PROBE_FROM_QUEUE = "probe:queue"
8+
9+
ELEMENT_ADDED_TO_QUEUE = ""
10+
QUEUE_IS_EMPTY = ""
11+
12+
13+
DEFAULT_EVENTS = [
14+
Events.OPTMIZATION_START,
15+
Events.OPTMIZATION_STEP,
16+
Events.OPTMIZATION_END,
17+
Events.ELEMENT_ADDED_TO_QUEUE,
18+
Events.QUEUE_IS_EMPTY,
19+
]

bayes_opt/observer.py

Lines changed: 146 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,163 @@
11
"""
2+
observers...
3+
"""
4+
import os
5+
import json
6+
from datetime import datetime
27

8+
from .event import Events
9+
from .util import Colours
10+
11+
12+
class _Tracker:
13+
def __init__(self):
14+
self._iterations = 0
15+
16+
self._previous_max = None
17+
self._previous_max_params = None
18+
19+
self._start_time = None
20+
self._previous_time = None
21+
22+
def _update_tracker(self, event, instance):
23+
if event == Events.OPTMIZATION_STEP:
24+
self._iterations += 1
25+
26+
current_max = instance.max
27+
if (self._previous_max is None or
28+
current_max["target"] > self._previous_max):
29+
self._previous_max = current_max["target"]
30+
self._previous_max_params = current_max["params"]
31+
32+
def _time_metrics(self):
33+
now = datetime.now()
34+
if self._start_time is None:
35+
self._start_time = now
36+
if self._previous_time is None:
37+
self._previous_time = now
38+
39+
time_elapsed = now - self._start_time
40+
time_delta = now - self._previous_time
41+
42+
self._previous_time = now
43+
return (
44+
now.strftime("%Y-%m-%d %H:%M:%S"),
45+
time_elapsed.total_seconds(),
46+
time_delta.total_seconds()
47+
)
348

4-
Inspired/Taken from https://www.protechtraining.com/blog/post/879#simple-observer
5-
"""
649

50+
def _get_default_logger(verbose):
51+
return ScreenLogger(verbose=verbose)
752

8-
class Events(object):
9-
INIT_DONE = 'initialized'
10-
FIT_STEP_DONE = 'fit_step_done'
11-
FIT_DONE = 'fit_done'
1253

54+
class ScreenLogger(_Tracker):
55+
_default_cell_size = 9
56+
_default_precision = 4
1357

14-
DEFAULT_EVENTS = [
15-
Events.INIT_DONE,
16-
Events.FIT_STEP_DONE,
17-
Events.FIT_DONE
18-
]
58+
def __init__(self, verbose=0):
59+
self._verbose = verbose
60+
self._header_length = None
61+
super(ScreenLogger, self).__init__()
1962

63+
@property
64+
def verbose(self):
65+
return self._verbose
2066

21-
class Observable(object):
22-
def __init__(self, events=None):
23-
# maps event names to subscribers
24-
# str -> dict
25-
if events is None:
26-
events = DEFAULT_EVENTS
67+
@verbose.setter
68+
def verbose(self, v):
69+
self._verbose = v
2770

28-
self.events = {event: dict() for event in events}
71+
def _format_number(self, x):
72+
if isinstance(x, int):
73+
s = "{x:< {s}}".format(
74+
x=x,
75+
s=self._default_cell_size,
76+
)
77+
else:
78+
s = "{x:< {s}.{p}}".format(
79+
x=x,
80+
s=self._default_cell_size,
81+
p=self._default_precision,
82+
)
2983

30-
def get_subscribers(self, event):
31-
return self.events[event]
84+
if len(s) > self._default_cell_size:
85+
if "." in s:
86+
return s[:self._default_cell_size]
87+
else:
88+
return s[:self._default_cell_size - 3] + "..."
89+
return s
3290

33-
def register(self, event, who, callback=None):
34-
if callback == None:
35-
callback = getattr(who, 'update')
36-
self.get_subscribers(event)[who] = callback
91+
def _format_key(self, key):
92+
s = "{key:^{s}}".format(
93+
key=key,
94+
s=self._default_cell_size
95+
)
96+
if len(s) > self._default_cell_size:
97+
return s[:self._default_cell_size - 3] + "..."
98+
return s
3799

38-
def unregister(self, event, who):
39-
del self.get_subscribers(event)[who]
100+
def _step(self, instance, colour=Colours.black):
101+
res = instance.res[-1]
102+
cells = []
40103

41-
def dispatch(self, event):
42-
for subscriber, callback in self.get_subscribers(event).items():
43-
callback(event, self)
104+
cells.append(self._format_number(self._iterations))
105+
cells.append(self._format_number(res["target"]))
44106

107+
for val in res["params"].values():
108+
cells.append(self._format_number(val))
109+
110+
return "| " + " | ".join(map(colour, cells)) + " |"
111+
112+
def _header(self, instance):
113+
cells = []
114+
cells.append(self._format_key("iter"))
115+
cells.append(self._format_key("target"))
116+
for key in instance.space.keys:
117+
cells.append(self._format_key(key))
118+
119+
line = "| " + " | ".join(cells) + " |"
120+
self._header_length = len(line)
121+
return line + "\n" + ("-" * self._header_length)
122+
123+
def update(self, event, instance):
124+
if event == Events.OPTMIZATION_START:
125+
line = self._header(instance)
126+
elif event == Events.OPTMIZATION_STEP:
127+
colour = (
128+
Colours.purple if
129+
self._previous_max is None or
130+
instance.max["target"] > self._previous_max else
131+
Colours.black
132+
)
133+
line = self._step(instance, colour=colour)
134+
elif event == Events.OPTMIZATION_END:
135+
line = "=" * self._header_length
136+
137+
print(line)
138+
self._update_tracker(event, instance)
139+
140+
class JSONLogger(_Tracker):
141+
def __init__(self, path):
142+
self._path = path if path[-5:] == ".json" else path + ".json"
143+
try:
144+
os.remove(self._path)
145+
except OSError:
146+
pass
147+
super(JSONLogger, self).__init__()
45148

46-
class Observer:
47149
def update(self, event, instance):
48-
# Avoid circular import
49-
from .bayesian_optimization import Events
50-
if event is Events.INIT_DONE:
51-
print("Initialization completed")
52-
elif event is Events.FIT_STEP_DONE:
53-
print("Optimization step finished, current max: ",
54-
instance.res['max'])
55-
elif event is Events.FIT_DONE:
56-
print("Optimization finished, maximum value at: ",
57-
instance.res['max'])
150+
if event == Events.OPTMIZATION_STEP:
151+
data = dict(instance.res[-1])
152+
153+
now, time_elapsed, time_delta = self._time_metrics()
154+
data["datetime"] = {
155+
"datetime": now,
156+
"elapsed": time_elapsed,
157+
"delta": time_delta,
158+
}
159+
160+
with open(self._path, "a") as f:
161+
f.write(json.dumps(data) + "\n")
162+
163+
self._update_tracker(event, instance)

bayes_opt/target_space.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from __future__ import print_function, division
21
import numpy as np
3-
from .helpers import ensure_rng, unique_rows
2+
from .util import ensure_rng, unique_rows
43

54

65
def _hashable(x):
@@ -226,7 +225,7 @@ def res(self):
226225
params = [dict(zip(self.keys, p)) for p in self.x]
227226

228227
return [
229-
{"target": target, "param": param}
228+
{"target": target, "params": param}
230229
for target, param in zip(self.target, params)
231230
]
232231

0 commit comments

Comments
 (0)