Skip to content

Commit 1e2c0ba

Browse files
committed
Add notebook [skip ci]
1 parent 7bfe011 commit 1e2c0ba

File tree

3 files changed

+77
-54
lines changed

3 files changed

+77
-54
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Welcome to our BayesFlow library for efficient simulation-based Bayesian workflo
88
For starters, check out some of our walk-through notebooks:
99

1010
1. [Quickstart amortized posterior estimation](docs/source/tutorial_notebooks/Intro_Amortized_Posterior_Estimation.ipynb)
11+
2. [Detecting model misspecification in posterior inference]((docs/source/tutorial_notebooks/Model_Misspecification.ipynb))
1112
3. [Principled Bayesian workflow for cognitive models](docs/source/tutorial_notebooks/LCA_Model_Posterior_Estimation.ipynb)
1213
4. [Posterior estimation for ODEs](docs/source/tutorial_notebooks/Linear_ODE_system.ipynb)
1314
5. [Posterior estimation for SIR-like models](docs/source/tutorial_notebooks/Covid19_Initial_Posterior_Estimation.ipynb)

bayesflow/sensitivity.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@ def misspecification_experiment(
3737
n_sim=200,
3838
configurator=None,
3939
):
40-
"""
41-
Performs a systematic sensitivity analysis with regard to 2 misspecification
42-
factors across different values of the factors provided in
40+
"""Performs a systematic sensitivity analysis with regard to two misspecification
41+
factors across different values of the factors provided in the config dictionaries.
4342
4443
Parameters
4544
----------
@@ -54,7 +53,7 @@ def misspecification_experiment(
5453
second_config_dict : dict
5554
Configuration for the second misspecification factor
5655
fields: name (str), values (1D np.ndarray)
57-
error_function: callable, default: bayesflow.computational_utilities.aggregated_rmse
56+
error_function : callable, default: bayesflow.computational_utilities.aggregated_rmse
5857
A callable that computes an error metric on the approximate posterior samples
5958
n_posterior_samples : int, optional, default: 500
6059
Number of samples from the approximate posterior per data set
@@ -63,11 +62,11 @@ def misspecification_experiment(
6362
configurator : callable or None, optional, default: None
6463
An optional configurator for the misspecified simulations.
6564
If ``None`` provided (default), ``Trainer.configurator`` will be used.
65+
6666
Returns
6767
-------
6868
posterior_error_dict: {P1, P2, value} - dictionary with misspecification grid (P1, P2) and posterior error results (values)
6969
summary_mmd: {P1, P2, values} - dictionary with misspecification grid (P1, P2) and summary MMD results (values)
70-
7170
"""
7271

7372
# Setup the grid and prepare placeholders
@@ -106,8 +105,7 @@ def misspecification_experiment(
106105

107106

108107
def plot_model_misspecification_sensitivity(results_dict, first_config_dict, second_config_dict, plot_config=None):
109-
"""
110-
Visualizes the results from a sensitivity analysis via a colored 2D grid.
108+
"""Visualizes the results from a sensitivity analysis via a colored 2D grid.
111109
112110
Parameters
113111
----------
@@ -127,7 +125,6 @@ def plot_model_misspecification_sensitivity(results_dict, first_config_dict, sec
127125
Returns
128126
-------
129127
f : plt.Figure - the figure instance for optional saving
130-
131128
"""
132129

133130
if plot_config is None:
@@ -188,50 +185,49 @@ def plot_color_grid(
188185
hline_location=None,
189186
vline_location=None,
190187
):
191-
"""
192-
Plots a 2-dimensional color grid.
188+
"""Plots a 2-dimensional color grid.
193189
194190
Parameters
195191
----------
196-
x_grid: np.ndarray
192+
x_grid : np.ndarray
197193
meshgrid of x values
198-
y_grid: np.ndarray
194+
y_grid : np.ndarray
199195
meshgrid of y values
200-
z_grid: np.ndarray
196+
z_grid : np.ndarray
201197
meshgrid of z values (coded by color in the plot)
202-
cmap: str, default: viridis
198+
cmap : str, default: viridis
203199
color map for the fill
204-
vmin: float, default: None
200+
vmin : float, default: None
205201
lower limit of the color map, None results in dynamic limit
206-
vmax: float, default: None
202+
vmax : float, default: None
207203
upper limit of the color map, None results in dynamic limit
208-
xlabel: str, default: x
209-
x label
210-
ylabel: str, default: y
211-
y label
212-
cbar_title: str, default: z
204+
xlabel : str, default: x
205+
x label text
206+
ylabel : str, default: y
207+
y label text
208+
cbar_title : str, default: z
213209
title of the color bar legend
214-
xticks: list, default: None
210+
xticks : list, default: None
215211
list of x ticks, None results in dynamic ticks
216-
yticks: list, default: None
212+
yticks : list, default: None
217213
list of y ticks, None results in dynamic ticks
218-
hline_location: float, default: None
214+
hline_location : float, default: None
219215
(optional) horizontal dashed line
220-
vline_location, float, default: None
216+
vline_location : float, default: None
221217
(optional) vertical dashed line
222218
223-
224219
Returns
225220
-------
226221
f : plt.Figure - the figure instance for optional saving
227222
"""
223+
228224
# Construct plot
229225
fig = plt.figure(figsize=(10, 5))
230226
plt.pcolor(x_grid, y_grid, z_grid, shading="nearest", rasterized=True, cmap=cmap, vmin=vmin, vmax=vmax)
231227
plt.xlabel(xlabel, fontsize=28)
232228
plt.ylabel(ylabel, fontsize=28)
233-
234229
plt.tick_params(labelsize=24)
230+
235231
if hline_location is not None:
236232
plt.axhline(y=hline_location, linestyle="--", color="lightgreen", alpha=0.80)
237233
if vline_location is not None:

docs/source/tutorial_notebooks/Model_Misspecification.ipynb

Lines changed: 53 additions & 27 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)