33__all__ = ['Rejection' , 'SMC' , 'BayesianOptimization' , 'BOLFI' ]
44
55import logging
6+ from collections import OrderedDict
67from math import ceil
78
8- import matplotlib .pyplot as plt
99import numpy as np
1010
1111import elfi .client
@@ -89,7 +89,6 @@ def __init__(self,
8989 model = model .model if isinstance (model , NodeReference ) else model
9090 if not model .parameter_names :
9191 raise ValueError ('Model {} defines no parameters' .format (model ))
92-
9392 self .model = model .copy ()
9493 self .output_names = self ._check_outputs (output_names )
9594
@@ -161,7 +160,7 @@ def extract_result(self):
161160 """
162161 raise NotImplementedError
163162
164- def update (self , batch , batch_index ):
163+ def update (self , batch , batch_index , vis = None ):
165164 """Update the inference state with a new batch.
166165
167166 ELFI calls this method when a new batch has been computed and the state of
@@ -174,10 +173,8 @@ def update(self, batch, batch_index):
174173 dict with `self.outputs` as keys and the corresponding outputs for the batch
175174 as values
176175 batch_index : int
177-
178- Returns
179- -------
180- None
176+ vis : bool, optional
177+ Interactive visualisation of the iterations.
181178
182179 """
183180 self .state ['n_batches' ] += 1
@@ -231,7 +228,7 @@ def plot_state(self, **kwargs):
231228 """
232229 raise NotImplementedError
233230
234- def infer (self , * args , vis = None , ** kwargs ):
231+ def infer (self , * args , ** options ):
235232 """Set the objective and start the iterate loop until the inference is finished.
236233
237234 See the other arguments from the `set_objective` method.
@@ -241,23 +238,16 @@ def infer(self, *args, vis=None, **kwargs):
241238 result : Sample
242239
243240 """
244- vis_opt = vis if isinstance (vis , dict ) else {}
245-
246- self .set_objective (* args , ** kwargs )
247-
241+ vis = options .pop ('vis' , None )
242+ self .set_objective (* args , ** options )
248243 while not self .finished :
249- self .iterate ()
250- if vis :
251- self .plot_state (interactive = True , ** vis_opt )
252-
244+ self .iterate (vis = vis )
253245 self .batches .cancel_pending ()
254- if vis :
255- self .plot_state (close = True , ** vis_opt )
256246
257247 return self .extract_result ()
258248
259- def iterate (self ):
260- """Advance the inference by one iteration.
249+ def iterate (self , vis = None ):
250+ """Forward the inference one iteration.
261251
262252 This is a way to manually progress the inference. One iteration consists of
263253 waiting and processing the result of the next batch in succession and possibly
@@ -272,6 +262,11 @@ def iterate(self):
272262 will never be more batches submitted in parallel than the `max_parallel_batches`
273263 setting allows.
274264
265+ Parameters
266+ ----------
267+ vis : bool, optional
268+ Interactive visualisation of the iterations.
269+
275270 Returns
276271 -------
277272 None
@@ -286,7 +281,7 @@ def iterate(self):
286281 # Handle the next ready batch in succession
287282 batch , batch_index = self .batches .wait_next ()
288283 logger .debug ('Received batch %d' % batch_index )
289- self .update (batch , batch_index )
284+ self .update (batch , batch_index , vis = vis )
290285
291286 @property
292287 def finished (self ):
@@ -466,17 +461,21 @@ def set_objective(self, n_samples, threshold=None, quantile=None, n_sim=None):
466461 # Reset the inference
467462 self .batches .reset ()
468463
469- def update (self , batch , batch_index ):
464+ def update (self , batch , batch_index , vis = None ):
470465 """Update the inference state with a new batch.
471466
472467 Parameters
473468 ----------
474469 batch : dict
475- dict with `self.outputs` as keys and the corresponding outputs for the batch
476- as values
470+ dict with `self.outputs` as keys and the corresponding outputs for the batch as values
471+ vis : bool, optional
472+ Interactive visualisation of the iterations.
477473 batch_index : int
478474
479475 """
476+ if vis and self .state ['samples' ] is not None :
477+ self .plot_state (interactive = True , ** vis )
478+
480479 super (Rejection , self ).update (batch , batch_index )
481480 if self .state ['samples' ] is None :
482481 # Lazy initialization of the outputs dict
@@ -584,8 +583,8 @@ def plot_state(self, **options):
584583 displays = []
585584 if options .get ('interactive' ):
586585 from IPython import display
587- displays . append (
588- display .HTML ('<span>Threshold: {}</span>' . format ( self . state [ 'threshold' ]) ))
586+ html_display = '<span>Threshold: {}</span>' . format ( self . state [ 'threshold' ])
587+ displays . append ( display .HTML (html_display ))
589588
590589 visin .plot_sample (
591590 self .state ['samples' ],
@@ -651,14 +650,15 @@ def extract_result(self):
651650 threshold = pop .threshold ,
652651 ** self ._extract_result_kwargs ())
653652
654- def update (self , batch , batch_index ):
653+ def update (self , batch , batch_index , vis = None ):
655654 """Update the inference state with a new batch.
656655
657656 Parameters
658657 ----------
659658 batch : dict
660- dict with `self.outputs` as keys and the corresponding outputs for the batch
661- as values
659+ dict with `self.outputs` as keys and the corresponding outputs for the batch as values
660+ vis : bool, optional
661+ Interactive visualisation of the iterations.
662662 batch_index : int
663663
664664 """
@@ -942,14 +942,16 @@ def extract_result(self):
942942 return OptimizationResult (
943943 x_min = batch_min , outputs = outputs , ** self ._extract_result_kwargs ())
944944
945- def update (self , batch , batch_index ):
945+ def update (self , batch , batch_index , vis = None ):
946946 """Update the GP regression model of the target node with a new batch.
947947
948948 Parameters
949949 ----------
950950 batch : dict
951951 dict with `self.outputs` as keys and the corresponding outputs for the batch
952952 as values
953+ vis : bool, optional
954+ Interactive visualisation of the iterations.
953955 batch_index : int
954956
955957 """
@@ -959,11 +961,22 @@ def update(self, batch, batch_index):
959961 params = batch_to_arr2d (batch , self .parameter_names )
960962 self ._report_batch (batch_index , params , batch [self .target_name ])
961963
964+ # Adding the acquisition plots.
965+ if vis and self .batches .next_index * self .batch_size > self .n_initial_evidence :
966+ options = {}
967+ options ['point_acq' ] = {'x' : params , 'd' : batch [self .target_name ]}
968+ options ['method_acq' ] = self .acquisition_method .label_fn
969+ arr_ax = self .plot_state (interactive = True , ** options )
970+
962971 optimize = self ._should_optimize ()
963972 self .target_model .update (params , batch [self .target_name ], optimize )
964973 if optimize :
965974 self .state ['last_GP_update' ] = self .target_model .n_evidence
966975
976+ # Adding the updated gp plots.
977+ if vis and self .batches .next_index * self .batch_size > self .n_initial_evidence :
978+ self .plot_state (interactive = True , arr_ax = arr_ax , ** options )
979+
967980 def prepare_new_batch (self , batch_index ):
968981 """Prepare values for a new batch.
969982
@@ -1040,60 +1053,51 @@ def _report_batch(self, batch_index, params, distances):
10401053 str += "{}{} at {}\n " .format (fill , distances [i ].item (), params [i ])
10411054 logger .debug (str )
10421055
1043- def plot_state (self , ** options ):
1044- """Plot the GP surface.
1056+ def plot_state (self , plot_acq_pairwise = False , arr_ax = None , ** options ):
1057+ """Plot the GP surface and the acquisition space .
10451058
1046- This feature is still experimental and currently supports only 2D cases.
1047- """
1048- f = plt .gcf ()
1049- if len (f .axes ) < 2 :
1050- f , _ = plt .subplots (1 , 2 , figsize = (13 , 6 ), sharex = 'row' , sharey = 'row' )
1051-
1052- gp = self .target_model
1053-
1054- # Draw the GP surface
1055- visin .draw_contour (
1056- gp .predict_mean ,
1057- gp .bounds ,
1058- self .parameter_names ,
1059- title = 'GP target surface' ,
1060- points = gp .X ,
1061- axes = f .axes [0 ],
1062- ** options )
1059+ Notes
1060+ -----
1061+ - The plots of the GP surface and the acquisition space work for the
1062+ cases when dim < 3;
1063+ - The method is experimental.
10631064
1064- # Draw the latest acquisitions
1065- if options .get ('interactive' ):
1066- point = gp .X [- 1 , :]
1067- if len (gp .X ) > 1 :
1068- f .axes [1 ].scatter (* point , color = 'red' )
1065+ Parameters
1066+ ----------
1067+ plot_acq_pairwise : bool, optional
1068+ The option to plot the pair-wise acquisition point relationships.
1069+ arr_ax : array_like, optional
1070+ Handled implicitly upon interactive visualisation.
10691071
1070- displays = [gp ._gp ]
1072+ Returns
1073+ -------
1074+ array_like
1075+ Axes for interactive visualisation.
10711076
1072- if options .get ('interactive' ):
1073- from IPython import display
1074- displays .insert (
1075- 0 ,
1076- display .HTML ('<span><b>Iteration {}:</b> Acquired {} at {}</span>' .format (
1077- len (gp .Y ), gp .Y [- 1 ][0 ], point )))
1078-
1079- # Update
1080- visin ._update_interactive (displays , options )
1081-
1082- def acq (x ):
1083- return self .acquisition_method .evaluate (x , len (gp .X ))
1084-
1085- # Draw the acquisition surface
1086- visin .draw_contour (
1087- acq ,
1088- gp .bounds ,
1089- self .parameter_names ,
1090- title = 'Acquisition surface' ,
1091- points = None ,
1092- axes = f .axes [1 ],
1093- ** options )
1077+ Raises
1078+ ------
1079+ ValueError
1080+ Unsupported dimension.
10941081
1095- if options .get ('close' ):
1096- plt .close ()
1082+ """
1083+ if plot_acq_pairwise :
1084+ if len (self .parameter_names ) == 1 :
1085+ raise ValueError ('Can not plot the pair-wise comparison for 1 parameter.' )
1086+
1087+ # Transform the acquisition points in the accepted format.
1088+ dict_pts_acq = OrderedDict ()
1089+ for idx_param , name_param in enumerate (self .parameter_names ):
1090+ dict_pts_acq [name_param ] = self .target_model .X [:, idx_param ]
1091+ vis .plot_pairs (dict_pts_acq , ** options )
1092+ else :
1093+ if len (self .parameter_names ) == 1 :
1094+ arr_ax = vis .plot_state_1d (self , arr_ax , ** options )
1095+ return arr_ax
1096+ elif len (self .parameter_names ) == 2 :
1097+ arr_ax = vis .plot_state_2d (self , arr_ax , ** options )
1098+ return arr_ax
1099+ else :
1100+ raise ValueError ('The method is supported only for 1- or 2-dimensions.' )
10971101
10981102 def plot_discrepancy (self , axes = None , ** kwargs ):
10991103 """Plot acquired parameters vs. resulting discrepancy.
@@ -1133,7 +1137,7 @@ class BOLFI(BayesianOptimization):
11331137
11341138 """
11351139
1136- def fit (self , n_evidence , threshold = None ):
1140+ def fit (self , n_evidence , threshold = None , ** options ):
11371141 """Fit the surrogate model.
11381142
11391143 Generates a regression model for the discrepancy given the parameters.
@@ -1150,9 +1154,8 @@ def fit(self, n_evidence, threshold=None):
11501154
11511155 if n_evidence is None :
11521156 raise ValueError (
1153- 'You must specify the number of evidence (n_evidence) for the fitting' )
1154-
1155- self .infer (n_evidence )
1157+ 'You must specify the number of evidence( n_evidence) for the fitting' )
1158+ self .infer (n_evidence , ** options )
11561159 return self .extract_posterior (threshold )
11571160
11581161 def extract_posterior (self , threshold = None ):
0 commit comments