Skip to content

Commit 6402a49

Browse files
author
perdaug
committed
[MaxVar split, Part 2] Added the visualisation improvements.
1 parent 077308b commit 6402a49

File tree

6 files changed

+343
-145
lines changed

6 files changed

+343
-145
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Changelog
22
=========
33

4+
0.x
5+
---
6+
- Improved the interactive plotting (customised for the MaxVar-based acquisition methods)
7+
- Added a pair-wise plotting to plot_state() (a way to visualise n-dimensional parameters)
8+
49
0.6.3 (2017-09-28)
510
------------------
611

elfi/methods/bo/acquisition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ def __init__(self, *args, delta=None, **kwargs):
201201
kwargs['exploration_rate'] = 1 / delta
202202

203203
super(LCBSC, self).__init__(*args, **kwargs)
204+
self.name = 'lcbsc'
205+
self.label_fn = 'The Lower Confidence Bound Selection Criterion'
204206

205207
@property
206208
def delta(self):

elfi/methods/parameter_inference.py

Lines changed: 86 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
__all__ = ['Rejection', 'SMC', 'BayesianOptimization', 'BOLFI']
44

55
import logging
6+
from collections import OrderedDict
67
from math import ceil
78

8-
import matplotlib.pyplot as plt
99
import numpy as np
1010

1111
import elfi.client
@@ -89,7 +89,6 @@ def __init__(self,
8989
model = model.model if isinstance(model, NodeReference) else model
9090
if not model.parameter_names:
9191
raise ValueError('Model {} defines no parameters'.format(model))
92-
9392
self.model = model.copy()
9493
self.output_names = self._check_outputs(output_names)
9594

@@ -161,7 +160,7 @@ def extract_result(self):
161160
"""
162161
raise NotImplementedError
163162

164-
def update(self, batch, batch_index):
163+
def update(self, batch, batch_index, vis=None):
165164
"""Update the inference state with a new batch.
166165
167166
ELFI calls this method when a new batch has been computed and the state of
@@ -174,10 +173,8 @@ def update(self, batch, batch_index):
174173
dict with `self.outputs` as keys and the corresponding outputs for the batch
175174
as values
176175
batch_index : int
177-
178-
Returns
179-
-------
180-
None
176+
vis : bool, optional
177+
Interactive visualisation of the iterations.
181178
182179
"""
183180
self.state['n_batches'] += 1
@@ -231,7 +228,7 @@ def plot_state(self, **kwargs):
231228
"""
232229
raise NotImplementedError
233230

234-
def infer(self, *args, vis=None, **kwargs):
231+
def infer(self, *args, **options):
235232
"""Set the objective and start the iterate loop until the inference is finished.
236233
237234
See the other arguments from the `set_objective` method.
@@ -241,23 +238,16 @@ def infer(self, *args, vis=None, **kwargs):
241238
result : Sample
242239
243240
"""
244-
vis_opt = vis if isinstance(vis, dict) else {}
245-
246-
self.set_objective(*args, **kwargs)
247-
241+
vis = options.pop('vis', None)
242+
self.set_objective(*args, **options)
248243
while not self.finished:
249-
self.iterate()
250-
if vis:
251-
self.plot_state(interactive=True, **vis_opt)
252-
244+
self.iterate(vis=vis)
253245
self.batches.cancel_pending()
254-
if vis:
255-
self.plot_state(close=True, **vis_opt)
256246

257247
return self.extract_result()
258248

259-
def iterate(self):
260-
"""Advance the inference by one iteration.
249+
def iterate(self, vis=None):
250+
"""Forward the inference one iteration.
261251
262252
This is a way to manually progress the inference. One iteration consists of
263253
waiting and processing the result of the next batch in succession and possibly
@@ -272,6 +262,11 @@ def iterate(self):
272262
will never be more batches submitted in parallel than the `max_parallel_batches`
273263
setting allows.
274264
265+
Parameters
266+
----------
267+
vis : bool, optional
268+
Interactive visualisation of the iterations.
269+
275270
Returns
276271
-------
277272
None
@@ -286,7 +281,7 @@ def iterate(self):
286281
# Handle the next ready batch in succession
287282
batch, batch_index = self.batches.wait_next()
288283
logger.debug('Received batch %d' % batch_index)
289-
self.update(batch, batch_index)
284+
self.update(batch, batch_index, vis=vis)
290285

291286
@property
292287
def finished(self):
@@ -466,17 +461,21 @@ def set_objective(self, n_samples, threshold=None, quantile=None, n_sim=None):
466461
# Reset the inference
467462
self.batches.reset()
468463

469-
def update(self, batch, batch_index):
464+
def update(self, batch, batch_index, vis=None):
470465
"""Update the inference state with a new batch.
471466
472467
Parameters
473468
----------
474469
batch : dict
475-
dict with `self.outputs` as keys and the corresponding outputs for the batch
476-
as values
470+
dict with `self.outputs` as keys and the corresponding outputs for the batch as values
471+
vis : bool, optional
472+
Interactive visualisation of the iterations.
477473
batch_index : int
478474
479475
"""
476+
if vis and self.state['samples'] is not None:
477+
self.plot_state(interactive=True, **vis)
478+
480479
super(Rejection, self).update(batch, batch_index)
481480
if self.state['samples'] is None:
482481
# Lazy initialization of the outputs dict
@@ -584,8 +583,8 @@ def plot_state(self, **options):
584583
displays = []
585584
if options.get('interactive'):
586585
from IPython import display
587-
displays.append(
588-
display.HTML('<span>Threshold: {}</span>'.format(self.state['threshold'])))
586+
html_display = '<span>Threshold: {}</span>'.format(self.state['threshold'])
587+
displays.append(display.HTML(html_display))
589588

590589
visin.plot_sample(
591590
self.state['samples'],
@@ -651,14 +650,15 @@ def extract_result(self):
651650
threshold=pop.threshold,
652651
**self._extract_result_kwargs())
653652

654-
def update(self, batch, batch_index):
653+
def update(self, batch, batch_index, vis=None):
655654
"""Update the inference state with a new batch.
656655
657656
Parameters
658657
----------
659658
batch : dict
660-
dict with `self.outputs` as keys and the corresponding outputs for the batch
661-
as values
659+
dict with `self.outputs` as keys and the corresponding outputs for the batch as values
660+
vis : bool, optional
661+
Interactive visualisation of the iterations.
662662
batch_index : int
663663
664664
"""
@@ -942,14 +942,16 @@ def extract_result(self):
942942
return OptimizationResult(
943943
x_min=batch_min, outputs=outputs, **self._extract_result_kwargs())
944944

945-
def update(self, batch, batch_index):
945+
def update(self, batch, batch_index, vis=None):
946946
"""Update the GP regression model of the target node with a new batch.
947947
948948
Parameters
949949
----------
950950
batch : dict
951951
dict with `self.outputs` as keys and the corresponding outputs for the batch
952952
as values
953+
vis : bool, optional
954+
Interactive visualisation of the iterations.
953955
batch_index : int
954956
955957
"""
@@ -959,11 +961,22 @@ def update(self, batch, batch_index):
959961
params = batch_to_arr2d(batch, self.parameter_names)
960962
self._report_batch(batch_index, params, batch[self.target_name])
961963

964+
# Adding the acquisition plots.
965+
if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence:
966+
options = {}
967+
options['point_acq'] = {'x': params, 'd': batch[self.target_name]}
968+
options['method_acq'] = self.acquisition_method.label_fn
969+
arr_ax = self.plot_state(interactive=True, **options)
970+
962971
optimize = self._should_optimize()
963972
self.target_model.update(params, batch[self.target_name], optimize)
964973
if optimize:
965974
self.state['last_GP_update'] = self.target_model.n_evidence
966975

976+
# Adding the updated gp plots.
977+
if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence:
978+
self.plot_state(interactive=True, arr_ax=arr_ax, **options)
979+
967980
def prepare_new_batch(self, batch_index):
968981
"""Prepare values for a new batch.
969982
@@ -1040,60 +1053,51 @@ def _report_batch(self, batch_index, params, distances):
10401053
str += "{}{} at {}\n".format(fill, distances[i].item(), params[i])
10411054
logger.debug(str)
10421055

1043-
def plot_state(self, **options):
1044-
"""Plot the GP surface.
1056+
def plot_state(self, plot_acq_pairwise=False, arr_ax=None, **options):
1057+
"""Plot the GP surface and the acquisition space.
10451058
1046-
This feature is still experimental and currently supports only 2D cases.
1047-
"""
1048-
f = plt.gcf()
1049-
if len(f.axes) < 2:
1050-
f, _ = plt.subplots(1, 2, figsize=(13, 6), sharex='row', sharey='row')
1051-
1052-
gp = self.target_model
1053-
1054-
# Draw the GP surface
1055-
visin.draw_contour(
1056-
gp.predict_mean,
1057-
gp.bounds,
1058-
self.parameter_names,
1059-
title='GP target surface',
1060-
points=gp.X,
1061-
axes=f.axes[0],
1062-
**options)
1059+
Notes
1060+
-----
1061+
- The plots of the GP surface and the acquisition space work for the
1062+
cases when dim < 3;
1063+
- The method is experimental.
10631064
1064-
# Draw the latest acquisitions
1065-
if options.get('interactive'):
1066-
point = gp.X[-1, :]
1067-
if len(gp.X) > 1:
1068-
f.axes[1].scatter(*point, color='red')
1065+
Parameters
1066+
----------
1067+
plot_acq_pairwise : bool, optional
1068+
The option to plot the pair-wise acquisition point relationships.
1069+
arr_ax : array_like, optional
1070+
Handled implicitly upon interactive visualisation.
10691071
1070-
displays = [gp._gp]
1072+
Returns
1073+
-------
1074+
array_like
1075+
Axes for interactive visualisation.
10711076
1072-
if options.get('interactive'):
1073-
from IPython import display
1074-
displays.insert(
1075-
0,
1076-
display.HTML('<span><b>Iteration {}:</b> Acquired {} at {}</span>'.format(
1077-
len(gp.Y), gp.Y[-1][0], point)))
1078-
1079-
# Update
1080-
visin._update_interactive(displays, options)
1081-
1082-
def acq(x):
1083-
return self.acquisition_method.evaluate(x, len(gp.X))
1084-
1085-
# Draw the acquisition surface
1086-
visin.draw_contour(
1087-
acq,
1088-
gp.bounds,
1089-
self.parameter_names,
1090-
title='Acquisition surface',
1091-
points=None,
1092-
axes=f.axes[1],
1093-
**options)
1077+
Raises
1078+
------
1079+
ValueError
1080+
Unsupported dimension.
10941081
1095-
if options.get('close'):
1096-
plt.close()
1082+
"""
1083+
if plot_acq_pairwise:
1084+
if len(self.parameter_names) == 1:
1085+
raise ValueError('Can not plot the pair-wise comparison for 1 parameter.')
1086+
1087+
# Transform the acquisition points in the accepted format.
1088+
dict_pts_acq = OrderedDict()
1089+
for idx_param, name_param in enumerate(self.parameter_names):
1090+
dict_pts_acq[name_param] = self.target_model.X[:, idx_param]
1091+
vis.plot_pairs(dict_pts_acq, **options)
1092+
else:
1093+
if len(self.parameter_names) == 1:
1094+
arr_ax = vis.plot_state_1d(self, arr_ax, **options)
1095+
return arr_ax
1096+
elif len(self.parameter_names) == 2:
1097+
arr_ax = vis.plot_state_2d(self, arr_ax, **options)
1098+
return arr_ax
1099+
else:
1100+
raise ValueError('The method is supported only for 1- or 2-dimensions.')
10971101

10981102
def plot_discrepancy(self, axes=None, **kwargs):
10991103
"""Plot acquired parameters vs. resulting discrepancy.
@@ -1133,7 +1137,7 @@ class BOLFI(BayesianOptimization):
11331137
11341138
"""
11351139

1136-
def fit(self, n_evidence, threshold=None):
1140+
def fit(self, n_evidence, threshold=None, **options):
11371141
"""Fit the surrogate model.
11381142
11391143
Generates a regression model for the discrepancy given the parameters.
@@ -1150,9 +1154,8 @@ def fit(self, n_evidence, threshold=None):
11501154

11511155
if n_evidence is None:
11521156
raise ValueError(
1153-
'You must specify the number of evidence (n_evidence) for the fitting')
1154-
1155-
self.infer(n_evidence)
1157+
'You must specify the number of evidence( n_evidence) for the fitting')
1158+
self.infer(n_evidence, **options)
11561159
return self.extract_posterior(threshold)
11571160

11581161
def extract_posterior(self, threshold=None):

0 commit comments

Comments
 (0)