33__all__ = ['Rejection' , 'SMC' , 'BayesianOptimization' , 'BOLFI' ]
44
55import logging
6+ from collections import OrderedDict
67from math import ceil
78
8- import matplotlib .pyplot as plt
99import numpy as np
1010
1111import elfi .client
@@ -89,7 +89,6 @@ def __init__(self,
8989 model = model .model if isinstance (model , NodeReference ) else model
9090 if not model .parameter_names :
9191 raise ValueError ('Model {} defines no parameters' .format (model ))
92-
9392 self .model = model .copy ()
9493 self .output_names = self ._check_outputs (output_names )
9594
@@ -161,7 +160,7 @@ def extract_result(self):
161160 """
162161 raise NotImplementedError
163162
164- def update (self , batch , batch_index ):
163+ def update (self , batch , batch_index , vis = None ):
165164 """Update the inference state with a new batch.
166165
167166 ELFI calls this method when a new batch has been computed and the state of
@@ -174,10 +173,8 @@ def update(self, batch, batch_index):
174173 dict with `self.outputs` as keys and the corresponding outputs for the batch
175174 as values
176175 batch_index : int
177-
178- Returns
179- -------
180- None
176+ vis : bool, optional
177+ Interactive visualisation of the iterations.
181178
182179 """
183180 self .state ['n_batches' ] += 1
@@ -231,7 +228,7 @@ def plot_state(self, **kwargs):
231228 """
232229 raise NotImplementedError
233230
def infer(self, *args, vis=None, **opts):
    """Set the objective and run the iterate loop until the inference is finished.

    All positional arguments and every keyword argument other than `vis`
    are forwarded to `set_objective`; see that method for their meaning.

    Parameters
    ----------
    vis : optional
        Visualisation setting. It is not interpreted here; it is handed
        unchanged to every `iterate` call as ``vis=vis``.

    Returns
    -------
    result : Sample

    """
    self.set_objective(*args, **opts)

    while not self.finished:
        self.iterate(vis=vis)

    # The loop is over: cancel whatever `self.batches` still has pending.
    self.batches.cancel_pending()

    return self.extract_result()
258248
259- def iterate (self ):
260- """Advance the inference by one iteration.
249+ def iterate (self , vis = None ):
250+ """Forward the inference one iteration.
261251
262252 This is a way to manually progress the inference. One iteration consists of
263253 waiting and processing the result of the next batch in succession and possibly
@@ -272,6 +262,11 @@ def iterate(self):
272262 will never be more batches submitted in parallel than the `max_parallel_batches`
273263 setting allows.
274264
265+ Parameters
266+ ----------
267+ vis : bool, optional
268+ Interactive visualisation of the iterations.
269+
275270 Returns
276271 -------
277272 None
@@ -286,7 +281,7 @@ def iterate(self):
286281 # Handle the next ready batch in succession
287282 batch , batch_index = self .batches .wait_next ()
288283 logger .debug ('Received batch %d' % batch_index )
289- self .update (batch , batch_index )
284+ self .update (batch , batch_index , vis = vis )
290285
291286 @property
292287 def finished (self ):
@@ -466,17 +461,21 @@ def set_objective(self, n_samples, threshold=None, quantile=None, n_sim=None):
466461 # Reset the inference
467462 self .batches .reset ()
468463
469- def update (self , batch , batch_index ):
464+ def update (self , batch , batch_index , vis = None ):
470465 """Update the inference state with a new batch.
471466
472467 Parameters
473468 ----------
474469 batch : dict
475- dict with `self.outputs` as keys and the corresponding outputs for the batch
476- as values
470+ dict with `self.outputs` as keys and the corresponding outputs for the batch as values
471+ vis : bool, optional
472+ Interactive visualisation of the iterations.
477473 batch_index : int
478474
479475 """
476+ if vis and self .state ['samples' ] is not None :
477+ self .plot_state (interactive = True , ** vis )
478+
480479 super (Rejection , self ).update (batch , batch_index )
481480 if self .state ['samples' ] is None :
482481 # Lazy initialization of the outputs dict
@@ -584,8 +583,8 @@ def plot_state(self, **options):
584583 displays = []
585584 if options .get ('interactive' ):
586585 from IPython import display
587- displays . append (
588- display .HTML ('<span>Threshold: {}</span>' . format ( self . state [ 'threshold' ]) ))
586+ html_display = '<span>Threshold: {}</span>' . format ( self . state [ 'threshold' ])
587+ displays . append ( display .HTML (html_display ))
589588
590589 visin .plot_sample (
591590 self .state ['samples' ],
@@ -651,14 +650,15 @@ def extract_result(self):
651650 threshold = pop .threshold ,
652651 ** self ._extract_result_kwargs ())
653652
654- def update (self , batch , batch_index ):
653+ def update (self , batch , batch_index , vis = None ):
655654 """Update the inference state with a new batch.
656655
657656 Parameters
658657 ----------
659658 batch : dict
660- dict with `self.outputs` as keys and the corresponding outputs for the batch
661- as values
659+ dict with `self.outputs` as keys and the corresponding outputs for the batch as values
660+ vis : bool, optional
661+ Interactive visualisation of the iterations.
662662 batch_index : int
663663
664664 """
@@ -833,7 +833,6 @@ def __init__(self,
833833 output_names = [target_name ] + model .parameter_names
834834 super (BayesianOptimization , self ).__init__ (
835835 model , output_names , batch_size = batch_size , ** kwargs )
836-
837836 target_model = target_model or \
838837 GPyRegression (self .model .parameter_names , bounds = bounds )
839838
@@ -942,14 +941,16 @@ def extract_result(self):
942941 return OptimizationResult (
943942 x_min = batch_min , outputs = outputs , ** self ._extract_result_kwargs ())
944943
945- def update (self , batch , batch_index ):
944+ def update (self , batch , batch_index , vis = None ):
946945 """Update the GP regression model of the target node with a new batch.
947946
948947 Parameters
949948 ----------
950949 batch : dict
951950 dict with `self.outputs` as keys and the corresponding outputs for the batch
952951 as values
952+ vis : bool, optional
953+ Interactive visualisation of the iterations.
953954 batch_index : int
954955
955956 """
@@ -958,12 +959,22 @@ def update(self, batch, batch_index):
958959
959960 params = batch_to_arr2d (batch , self .parameter_names )
960961 self ._report_batch (batch_index , params , batch [self .target_name ])
962+ # Adding the acquisition plots.
963+ if vis and self .batches .next_index * self .batch_size > self .n_initial_evidence :
964+ opts = {}
965+ opts ['point_acq' ] = {'x' : params , 'd' : batch [self .target_name ]}
966+ opts ['method_acq' ] = self .acquisition_method .label_fn
967+ arr_ax = self .plot_state (interactive = True , ** opts )
961968
962969 optimize = self ._should_optimize ()
963970 self .target_model .update (params , batch [self .target_name ], optimize )
964971 if optimize :
965972 self .state ['last_GP_update' ] = self .target_model .n_evidence
966973
974+ # Adding the updated gp plots.
975+ if vis and self .batches .next_index * self .batch_size > self .n_initial_evidence :
976+ self .plot_state (interactive = True , arr_ax = arr_ax , ** opts )
977+
967978 def prepare_new_batch (self , batch_index ):
968979 """Prepare values for a new batch.
969980
@@ -980,7 +991,6 @@ def prepare_new_batch(self, batch_index):
980991
981992 """
982993 t = self ._get_acquisition_index (batch_index )
983-
984994 # Check if we still should take initial points from the prior
985995 if t < 0 :
986996 return
@@ -1040,60 +1050,40 @@ def _report_batch(self, batch_index, params, distances):
10401050 str += "{}{} at {}\n " .format (fill , distances [i ].item (), params [i ])
10411051 logger .debug (str )
10421052
def plot_state(self, plot_acq_pairwise=False, arr_ax=None, **opts):
    """Plot the GP surface and the acquisition space.

    Notes
    -----
    - The GP surface / acquisition space plots are only produced when the
      model has 1 or 2 parameters;
    - The method is experimental.

    Parameters
    ----------
    plot_acq_pairwise : bool, optional
        When true, plot the pair-wise relationships of the acquisition
        points (the columns of ``self.target_model.X``) instead of the
        GP surface.
    arr_ax : optional
        Passed through to ``vis.plot_state_1d`` / ``vis.plot_state_2d``.
        Presumably previously created axes to draw on -- confirm against
        those helpers.
    **opts
        Extra keyword arguments forwarded to the plotting helper in use.

    Returns
    -------
    The value returned by ``vis.plot_state_1d`` / ``vis.plot_state_2d``
    for the 1- and 2-parameter cases; None in every other case.

    """
    n_params = len(self.parameter_names)

    if plot_acq_pairwise:
        if n_params == 1:
            logger.info('Can not plot the pair-wise comparison for 1 parameter.')
            return None
        # One entry per parameter, keeping the order of `parameter_names`.
        acq_points = OrderedDict(
            (name, self.target_model.X[:, column])
            for column, name in enumerate(self.parameter_names))
        vis.plot_pairs(acq_points, **opts)
        return None

    if n_params not in (1, 2):
        logger.info('The method is supported for 1- or 2-dimensions.')
        return None

    if n_params == 1:
        return vis.plot_state_1d(self, arr_ax, **opts)
    return vis.plot_state_2d(self, arr_ax, **opts)
10971087
10981088 def plot_discrepancy (self , axes = None , ** kwargs ):
10991089 """Plot acquired parameters vs. resulting discrepancy.
@@ -1133,7 +1123,7 @@ class BOLFI(BayesianOptimization):
11331123
11341124 """
11351125
1136- def fit (self , n_evidence , threshold = None ):
1126+ def fit (self , n_evidence , threshold = None , ** opts ):
11371127 """Fit the surrogate model.
11381128
11391129 Generates a regression model for the discrepancy given the parameters.
@@ -1150,9 +1140,8 @@ def fit(self, n_evidence, threshold=None):
11501140
11511141 if n_evidence is None :
11521142 raise ValueError (
1153- 'You must specify the number of evidence (n_evidence) for the fitting' )
1154-
1155- self .infer (n_evidence )
1143+ 'You must specify the number of evidence( n_evidence) for the fitting' )
1144+ self .infer (n_evidence , ** opts )
11561145 return self .extract_posterior (threshold )
11571146
11581147 def extract_posterior (self , threshold = None ):
@@ -1235,12 +1224,10 @@ def sample(self,
12351224 else :
12361225 inds = np .argsort (self .target_model .Y [:, 0 ])
12371226 initials = np .asarray (self .target_model .X [inds ])
1238-
12391227 self .target_model .is_sampling = True # enables caching for default RBF kernel
12401228
12411229 tasks_ids = []
12421230 ii_initial = 0
1243-
12441231 # sampling is embarrassingly parallel, so depending on self.client this may parallelize
12451232 for ii in range (n_chains ):
12461233 seed = get_sub_seed (self .seed , ii )
@@ -1270,12 +1257,12 @@ def sample(self,
12701257
12711258 chains = np .asarray (chains )
12721259
1273- print (
1274- "{} chains of {} iterations acquired. Effective sample size and Rhat for each "
1275- "parameter:" . format (n_chains , n_samples ))
1260+ logger . info (
1261+ "%d chains of %d iterations acquired. Effective sample size and Rhat for each "
1262+ "parameter:" % (n_chains , n_samples ))
12761263 for ii , node in enumerate (self .parameter_names ):
1277- print ( node , mcmc . eff_sample_size ( chains [:, :, ii ]),
1278- mcmc .gelman_rubin ( chains [:, :, ii ] ))
1264+ chain = chains [:, :, ii ]
1265+ logger . info ( "%s %d %d" % ( node , mcmc .eff_sample_size ( chain ), mcmc . gelman_rubin ( chain ) ))
12791266
12801267 self .target_model .is_sampling = False
12811268
0 commit comments