Skip to content

Commit 3fa174f

Browse files
author
perdaug
committed
[MaxVar split, Part 2] Added the visualisation improvements.
1 parent 1d8900c commit 3fa174f

File tree

7 files changed

+312
-150
lines changed

7 files changed

+312
-150
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ Changelog
1010
- Improved performance when rerunning inference using stored data
1111
- Change SMC to use ModelPrior, used to immediately reject invalid proposals
1212
- Added the general Gaussian noise example model (fixed covariance)
13+
- Improved the interactive plotting (customised for the MaxVar-based acquisition methods)
14+
- Added a pair-wise plotting to plot_state() (a way to visualise n-dimensional parameters)
1315

1416
0.6.1 (2017-07-21)
1517
------------------

elfi/methods/bo/acquisition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ def __init__(self, *args, delta=None, **kwargs):
201201
kwargs['exploration_rate'] = 1 / delta
202202

203203
super(LCBSC, self).__init__(*args, **kwargs)
204+
self.name = 'lcbsc'
205+
self.label_fn = 'The Lower Confidence Bound Selection Criterion'
204206

205207
@property
206208
def delta(self):

elfi/methods/bo/gpy_regression.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,16 @@ def Y(self):
338338
"""Return output evidence."""
339339
return self._gp.Y
340340

341+
@property
342+
def noise(self):
343+
"""Return the noise."""
344+
return self._gp.Gaussian_noise.variance[0]
345+
346+
@property
347+
def instance(self):
348+
"""Return the gp instance."""
349+
return self._gp
350+
341351
def copy(self):
342352
"""Return a copy of current instance."""
343353
kopy = copy.copy(self)

elfi/methods/parameter_inference.py

Lines changed: 81 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
__all__ = ['Rejection', 'SMC', 'BayesianOptimization', 'BOLFI']
44

55
import logging
6+
from collections import OrderedDict
67
from math import ceil
78

8-
import matplotlib.pyplot as plt
99
import numpy as np
1010

1111
import elfi.client
@@ -89,7 +89,6 @@ def __init__(self,
8989
model = model.model if isinstance(model, NodeReference) else model
9090
if not model.parameter_names:
9191
raise ValueError('Model {} defines no parameters'.format(model))
92-
9392
self.model = model.copy()
9493
self.output_names = self._check_outputs(output_names)
9594

@@ -161,7 +160,7 @@ def extract_result(self):
161160
"""
162161
raise NotImplementedError
163162

164-
def update(self, batch, batch_index):
163+
def update(self, batch, batch_index, vis=None):
165164
"""Update the inference state with a new batch.
166165
167166
ELFI calls this method when a new batch has been computed and the state of
@@ -174,10 +173,8 @@ def update(self, batch, batch_index):
174173
dict with `self.outputs` as keys and the corresponding outputs for the batch
175174
as values
176175
batch_index : int
177-
178-
Returns
179-
-------
180-
None
176+
vis : bool, optional
177+
Interactive visualisation of the iterations.
181178
182179
"""
183180
self.state['n_batches'] += 1
@@ -231,7 +228,7 @@ def plot_state(self, **kwargs):
231228
"""
232229
raise NotImplementedError
233230

234-
def infer(self, *args, vis=None, **kwargs):
231+
def infer(self, *args, **opts):
235232
"""Set the objective and start the iterate loop until the inference is finished.
236233
237234
See the other arguments from the `set_objective` method.
@@ -241,23 +238,16 @@ def infer(self, *args, vis=None, **kwargs):
241238
result : Sample
242239
243240
"""
244-
vis_opt = vis if isinstance(vis, dict) else {}
245-
246-
self.set_objective(*args, **kwargs)
247-
241+
vis = opts.pop('vis', None)
242+
self.set_objective(*args, **opts)
248243
while not self.finished:
249-
self.iterate()
250-
if vis:
251-
self.plot_state(interactive=True, **vis_opt)
252-
244+
self.iterate(vis=vis)
253245
self.batches.cancel_pending()
254-
if vis:
255-
self.plot_state(close=True, **vis_opt)
256246

257247
return self.extract_result()
258248

259-
def iterate(self):
260-
"""Advance the inference by one iteration.
249+
def iterate(self, vis=None):
250+
"""Forward the inference one iteration.
261251
262252
This is a way to manually progress the inference. One iteration consists of
263253
waiting and processing the result of the next batch in succession and possibly
@@ -272,6 +262,11 @@ def iterate(self):
272262
will never be more batches submitted in parallel than the `max_parallel_batches`
273263
setting allows.
274264
265+
Parameters
266+
----------
267+
vis : bool, optional
268+
Interactive visualisation of the iterations.
269+
275270
Returns
276271
-------
277272
None
@@ -286,7 +281,7 @@ def iterate(self):
286281
# Handle the next ready batch in succession
287282
batch, batch_index = self.batches.wait_next()
288283
logger.debug('Received batch %d' % batch_index)
289-
self.update(batch, batch_index)
284+
self.update(batch, batch_index, vis=vis)
290285

291286
@property
292287
def finished(self):
@@ -466,17 +461,21 @@ def set_objective(self, n_samples, threshold=None, quantile=None, n_sim=None):
466461
# Reset the inference
467462
self.batches.reset()
468463

469-
def update(self, batch, batch_index):
464+
def update(self, batch, batch_index, vis=None):
470465
"""Update the inference state with a new batch.
471466
472467
Parameters
473468
----------
474469
batch : dict
475-
dict with `self.outputs` as keys and the corresponding outputs for the batch
476-
as values
470+
dict with `self.outputs` as keys and the corresponding outputs for the batch as values
471+
vis : bool, optional
472+
Interactive visualisation of the iterations.
477473
batch_index : int
478474
479475
"""
476+
if vis and self.state['samples'] is not None:
477+
self.plot_state(interactive=True, **vis)
478+
480479
super(Rejection, self).update(batch, batch_index)
481480
if self.state['samples'] is None:
482481
# Lazy initialization of the outputs dict
@@ -584,8 +583,8 @@ def plot_state(self, **options):
584583
displays = []
585584
if options.get('interactive'):
586585
from IPython import display
587-
displays.append(
588-
display.HTML('<span>Threshold: {}</span>'.format(self.state['threshold'])))
586+
html_display = '<span>Threshold: {}</span>'.format(self.state['threshold'])
587+
displays.append(display.HTML(html_display))
589588

590589
visin.plot_sample(
591590
self.state['samples'],
@@ -651,14 +650,15 @@ def extract_result(self):
651650
threshold=pop.threshold,
652651
**self._extract_result_kwargs())
653652

654-
def update(self, batch, batch_index):
653+
def update(self, batch, batch_index, vis=None):
655654
"""Update the inference state with a new batch.
656655
657656
Parameters
658657
----------
659658
batch : dict
660-
dict with `self.outputs` as keys and the corresponding outputs for the batch
661-
as values
659+
dict with `self.outputs` as keys and the corresponding outputs for the batch as values
660+
vis : bool, optional
661+
Interactive visualisation of the iterations.
662662
batch_index : int
663663
664664
"""
@@ -833,7 +833,6 @@ def __init__(self,
833833
output_names = [target_name] + model.parameter_names
834834
super(BayesianOptimization, self).__init__(
835835
model, output_names, batch_size=batch_size, **kwargs)
836-
837836
target_model = target_model or \
838837
GPyRegression(self.model.parameter_names, bounds=bounds)
839838

@@ -942,14 +941,16 @@ def extract_result(self):
942941
return OptimizationResult(
943942
x_min=batch_min, outputs=outputs, **self._extract_result_kwargs())
944943

945-
def update(self, batch, batch_index):
944+
def update(self, batch, batch_index, vis=None):
946945
"""Update the GP regression model of the target node with a new batch.
947946
948947
Parameters
949948
----------
950949
batch : dict
951950
dict with `self.outputs` as keys and the corresponding outputs for the batch
952951
as values
952+
vis : bool, optional
953+
Interactive visualisation of the iterations.
953954
batch_index : int
954955
955956
"""
@@ -958,12 +959,22 @@ def update(self, batch, batch_index):
958959

959960
params = batch_to_arr2d(batch, self.parameter_names)
960961
self._report_batch(batch_index, params, batch[self.target_name])
962+
# Adding the acquisition plots.
963+
if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence:
964+
opts = {}
965+
opts['point_acq'] = {'x': params, 'd': batch[self.target_name]}
966+
opts['method_acq'] = self.acquisition_method.label_fn
967+
arr_ax = self.plot_state(interactive=True, **opts)
961968

962969
optimize = self._should_optimize()
963970
self.target_model.update(params, batch[self.target_name], optimize)
964971
if optimize:
965972
self.state['last_GP_update'] = self.target_model.n_evidence
966973

974+
# Adding the updated gp plots.
975+
if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence:
976+
self.plot_state(interactive=True, arr_ax=arr_ax, **opts)
977+
967978
def prepare_new_batch(self, batch_index):
968979
"""Prepare values for a new batch.
969980
@@ -980,7 +991,6 @@ def prepare_new_batch(self, batch_index):
980991
981992
"""
982993
t = self._get_acquisition_index(batch_index)
983-
984994
# Check if we still should take initial points from the prior
985995
if t < 0:
986996
return
@@ -1040,60 +1050,40 @@ def _report_batch(self, batch_index, params, distances):
10401050
str += "{}{} at {}\n".format(fill, distances[i].item(), params[i])
10411051
logger.debug(str)
10421052

1043-
def plot_state(self, **options):
1044-
"""Plot the GP surface.
1045-
1046-
This feature is still experimental and currently supports only 2D cases.
1047-
"""
1048-
f = plt.gcf()
1049-
if len(f.axes) < 2:
1050-
f, _ = plt.subplots(1, 2, figsize=(13, 6), sharex='row', sharey='row')
1051-
1052-
gp = self.target_model
1053-
1054-
# Draw the GP surface
1055-
visin.draw_contour(
1056-
gp.predict_mean,
1057-
gp.bounds,
1058-
self.parameter_names,
1059-
title='GP target surface',
1060-
points=gp.X,
1061-
axes=f.axes[0],
1062-
**options)
1063-
1064-
# Draw the latest acquisitions
1065-
if options.get('interactive'):
1066-
point = gp.X[-1, :]
1067-
if len(gp.X) > 1:
1068-
f.axes[1].scatter(*point, color='red')
1053+
def plot_state(self, plot_acq_pairwise=False, arr_ax=None, **opts):
1054+
"""Plot the GP surface and the acquisition space.
10691055
1070-
displays = [gp._gp]
1056+
Notes
1057+
-----
1058+
- The plots of the GP surface and the acquisition space work for the
1059+
cases when dim < 3;
1060+
- The method is experimental.
10711061
1072-
if options.get('interactive'):
1073-
from IPython import display
1074-
displays.insert(
1075-
0,
1076-
display.HTML('<span><b>Iteration {}:</b> Acquired {} at {}</span>'.format(
1077-
len(gp.Y), gp.Y[-1][0], point)))
1078-
1079-
# Update
1080-
visin._update_interactive(displays, options)
1081-
1082-
def acq(x):
1083-
return self.acquisition_method.evaluate(x, len(gp.X))
1084-
1085-
# Draw the acquisition surface
1086-
visin.draw_contour(
1087-
acq,
1088-
gp.bounds,
1089-
self.parameter_names,
1090-
title='Acquisition surface',
1091-
points=None,
1092-
axes=f.axes[1],
1093-
**options)
1062+
Parameters
1063+
----------
1064+
plot_acq_pairwise : bool, optional
1065+
The option to plot the pair-wise acquisition point relationships.
10941066
1095-
if options.get('close'):
1096-
plt.close()
1067+
"""
1068+
if plot_acq_pairwise:
1069+
if len(self.parameter_names) == 1:
1070+
logger.info('Can not plot the pair-wise comparison for 1 parameter.')
1071+
return
1072+
# Transform the acquisition points in the acceptable format.
1073+
dict_pts_acq = OrderedDict()
1074+
for idx_param, name_param in enumerate(self.parameter_names):
1075+
dict_pts_acq[name_param] = self.target_model.X[:, idx_param]
1076+
vis.plot_pairs(dict_pts_acq, **opts)
1077+
else:
1078+
if len(self.parameter_names) == 1:
1079+
arr_ax = vis.plot_state_1d(self, arr_ax, **opts)
1080+
return arr_ax
1081+
elif len(self.parameter_names) == 2:
1082+
arr_ax = vis.plot_state_2d(self, arr_ax, **opts)
1083+
return arr_ax
1084+
else:
1085+
logger.info('The method is supported only for 1 or 2 dimensions.')
1086+
return
10971087

10981088
def plot_discrepancy(self, axes=None, **kwargs):
10991089
"""Plot acquired parameters vs. resulting discrepancy.
@@ -1133,7 +1123,7 @@ class BOLFI(BayesianOptimization):
11331123
11341124
"""
11351125

1136-
def fit(self, n_evidence, threshold=None):
1126+
def fit(self, n_evidence, threshold=None, **opts):
11371127
"""Fit the surrogate model.
11381128
11391129
Generates a regression model for the discrepancy given the parameters.
@@ -1150,9 +1140,8 @@ def fit(self, n_evidence, threshold=None):
11501140

11511141
if n_evidence is None:
11521142
raise ValueError(
1153-
'You must specify the number of evidence (n_evidence) for the fitting')
1154-
1155-
self.infer(n_evidence)
1143+
'You must specify the number of evidence (n_evidence) for the fitting')
1144+
self.infer(n_evidence, **opts)
11561145
return self.extract_posterior(threshold)
11571146

11581147
def extract_posterior(self, threshold=None):
@@ -1235,12 +1224,10 @@ def sample(self,
12351224
else:
12361225
inds = np.argsort(self.target_model.Y[:, 0])
12371226
initials = np.asarray(self.target_model.X[inds])
1238-
12391227
self.target_model.is_sampling = True # enables caching for default RBF kernel
12401228

12411229
tasks_ids = []
12421230
ii_initial = 0
1243-
12441231
# sampling is embarrassingly parallel, so depending on self.client this may parallelize
12451232
for ii in range(n_chains):
12461233
seed = get_sub_seed(self.seed, ii)
@@ -1270,12 +1257,12 @@ def sample(self,
12701257

12711258
chains = np.asarray(chains)
12721259

1273-
print(
1274-
"{} chains of {} iterations acquired. Effective sample size and Rhat for each "
1275-
"parameter:".format(n_chains, n_samples))
1260+
logger.info(
1261+
"%d chains of %d iterations acquired. Effective sample size and Rhat for each "
1262+
"parameter:" % (n_chains, n_samples))
12761263
for ii, node in enumerate(self.parameter_names):
1277-
print(node, mcmc.eff_sample_size(chains[:, :, ii]),
1278-
mcmc.gelman_rubin(chains[:, :, ii]))
1264+
chain = chains[:, :, ii]
1265+
logger.info("%s %d %d" % (node, mcmc.eff_sample_size(chain), mcmc.gelman_rubin(chain)))
12791266

12801267
self.target_model.is_sampling = False
12811268

0 commit comments

Comments
 (0)