@@ -37,9 +37,8 @@ def misspecification_experiment(
3737 n_sim = 200 ,
3838 configurator = None ,
3939):
40- """
41- Performs a systematic sensitivity analysis with regard to 2 misspecification
42- factors across different values of the factors provided in
40+ """Performs a systematic sensitivity analysis with regard to two misspecification
41+ factors across different values of the factors provided in the config dictionaries.
4342
4443 Parameters
4544 ----------
@@ -54,7 +53,7 @@ def misspecification_experiment(
5453 second_config_dict : dict
5554 Configuration for the second misspecification factor
5655 fields: name (str), values (1D np.ndarray)
57- error_function: callable, default: bayesflow.computational_utilities.aggregated_rmse
56+ error_function : callable, default: bayesflow.computational_utilities.aggregated_rmse
5857 A callable that computes an error metric on the approximate posterior samples
5958 n_posterior_samples : int, optional, default: 500
6059 Number of samples from the approximate posterior per data set
@@ -63,11 +62,11 @@ def misspecification_experiment(
6362 configurator : callable or None, optional, default: None
6463 An optional configurator for the misspecified simulations.
6564 If ``None`` provided (default), ``Trainer.configurator`` will be used.
65+
6666 Returns
6767 -------
6868 posterior_error_dict: {P1, P2, value} - dictionary with misspecification grid (P1, P2) and posterior error results (values)
6969 summary_mmd: {P1, P2, values} - dictionary with misspecification grid (P1, P2) and summary MMD results (values)
70-
7170 """
7271
7372 # Setup the grid and prepare placeholders
@@ -106,8 +105,7 @@ def misspecification_experiment(
106105
107106
108107def plot_model_misspecification_sensitivity (results_dict , first_config_dict , second_config_dict , plot_config = None ):
109- """
110- Visualizes the results from a sensitivity analysis via a colored 2D grid.
108+ """Visualizes the results from a sensitivity analysis via a colored 2D grid.
111109
112110 Parameters
113111 ----------
@@ -127,7 +125,6 @@ def plot_model_misspecification_sensitivity(results_dict, first_config_dict, sec
127125 Returns
128126 -------
129127 f : plt.Figure - the figure instance for optional saving
130-
131128 """
132129
133130 if plot_config is None :
@@ -188,50 +185,49 @@ def plot_color_grid(
188185 hline_location = None ,
189186 vline_location = None ,
190187):
191- """
192- Plots a 2-dimensional color grid.
188+ """Plots a 2-dimensional color grid.
193189
194190 Parameters
195191 ----------
196- x_grid: np.ndarray
192+ x_grid : np.ndarray
197193 meshgrid of x values
198- y_grid: np.ndarray
194+ y_grid : np.ndarray
199195 meshgrid of y values
200- z_grid: np.ndarray
196+ z_grid : np.ndarray
201197 meshgrid of z values (coded by color in the plot)
202- cmap: str, default: viridis
198+ cmap : str, default: viridis
203199 color map for the fill
204- vmin: float, default: None
200+ vmin : float, default: None
205201 lower limit of the color map, None results in dynamic limit
206- vmax: float, default: None
202+ vmax : float, default: None
207203 upper limit of the color map, None results in dynamic limit
208- xlabel: str, default: x
209- x label
210- ylabel: str, default: y
211- y label
212- cbar_title: str, default: z
204+ xlabel : str, default: x
205+ x label text
206+ ylabel : str, default: y
207+ y label text
208+ cbar_title : str, default: z
213209 title of the color bar legend
214- xticks: list, default: None
210+ xticks : list, default: None
215211 list of x ticks, None results in dynamic ticks
216- yticks: list, default: None
212+ yticks : list, default: None
217213 list of y ticks, None results in dynamic ticks
218- hline_location: float, default: None
214+ hline_location : float, default: None
219215 (optional) horizontal dashed line
220- vline_location, float, default: None
216+ vline_location : float, default: None
221217 (optional) vertical dashed line
222218
223-
224219 Returns
225220 -------
226221 f : plt.Figure - the figure instance for optional saving
227222 """
223+
228224 # Construct plot
229225 fig = plt .figure (figsize = (10 , 5 ))
230226 plt .pcolor (x_grid , y_grid , z_grid , shading = "nearest" , rasterized = True , cmap = cmap , vmin = vmin , vmax = vmax )
231227 plt .xlabel (xlabel , fontsize = 28 )
232228 plt .ylabel (ylabel , fontsize = 28 )
233-
234229 plt .tick_params (labelsize = 24 )
230+
235231 if hline_location is not None :
236232 plt .axhline (y = hline_location , linestyle = "--" , color = "lightgreen" , alpha = 0.80 )
237233 if vline_location is not None :
0 commit comments