Commit 7bfe011

Merge remote-tracking branch 'origin/Development' into Development
# Conflicts:
#	bayesflow/diagnostics.py
2 parents 717c941 + 7bd3b41 commit 7bfe011

8 files changed: +2172 -734 lines changed

README.md

Lines changed: 78 additions & 5 deletions
@@ -8,9 +8,11 @@ Welcome to our BayesFlow library for efficient simulation-based Bayesian workflo
 For starters, check out some of our walk-through notebooks:

 1. [Quickstart amortized posterior estimation](docs/source/tutorial_notebooks/Intro_Amortized_Posterior_Estimation.ipynb)
-2. [Principled Bayesian workflow for cognitive models](docs/source/tutorial_notebooks/LCA_Model_Posterior_Estimation.ipynb)
-3. [Posterior estimation for ODEs](docs/source/tutorial_notebooks/Linear_ODE_system.ipynb)
-4. [Posterior estimation for SIR-like models](docs/source/tutorial_notebooks/Covid19_Initial_Posterior_Estimation.ipynb)
+3. [Principled Bayesian workflow for cognitive models](docs/source/tutorial_notebooks/LCA_Model_Posterior_Estimation.ipynb)
+4. [Posterior estimation for ODEs](docs/source/tutorial_notebooks/Linear_ODE_system.ipynb)
+5. [Posterior estimation for SIR-like models](docs/source/tutorial_notebooks/Covid19_Initial_Posterior_Estimation.ipynb)
+6. [Model comparison for cognitive models](docs/source/tutorial_notebooks/Model_Comparison_MPT.ipynb)
+7. [Hierarchical model comparison for cognitive models](docs/source/tutorial_notebooks/Hierarchical_Model_Comparison_MPT.ipynb)

 ## Project Documentation

@@ -167,7 +169,7 @@ to the `AmortizedPosterior` instance:
 amortizer = bf.amortizers.AmortizedPosterior(inference_net, summary_net, summary_loss_fun='MMD')
 ```

-The amortizer knows how to combine its losses.
+The amortizer knows how to combine its losses and you can inspect the summary space for outliers during inference.

 ### References and Further Reading

@@ -177,7 +179,74 @@ preprint</em>, available for free at: https://arxiv.org/abs/2112.08866

 ## Model Comparison

-Example coming soon...
+BayesFlow can not only be used for parameter estimation, but also to perform approximate Bayesian model comparison via posterior model probabilities or Bayes factors.
+Let's extend the minimal example from before with a second model $M_2$ that we want to compare with our original model $M_1$:
+
+```python
+def simulator(theta, n_obs=50, scale=1.0):
+    return np.random.default_rng().normal(loc=theta, scale=scale, size=(n_obs, theta.shape[0]))
+
+def prior_m1(D=2, mu=0., sigma=1.0):
+    return np.random.default_rng().normal(loc=mu, scale=sigma, size=D)
+
+def prior_m2(D=2, mu=2., sigma=1.0):
+    return np.random.default_rng().normal(loc=mu, scale=sigma, size=D)
+```
+
+For the purpose of this illustration, the two toy models only differ with respect to their prior specification ($M_1: \mu = 0, M_2: \mu = 2$). We create both models as before and use a `MultiGenerativeModel` wrapper to combine them in a `meta_model`:
+
+```python
+model_m1 = bf.simulation.GenerativeModel(prior_m1, simulator, simulator_is_batched=False)
+model_m2 = bf.simulation.GenerativeModel(prior_m2, simulator, simulator_is_batched=False)
+meta_model = bf.simulation.MultiGenerativeModel([model_m1, model_m2])
+```
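
To make the role of the `MultiGenerativeModel` wrapper concrete, here is a conceptual sketch in plain Python of what drawing from a two-model `meta_model` amounts to: pick a model index, draw parameters from that model's prior, then simulate data. This is not BayesFlow's implementation. The helper `sample_from_meta_model`, the explicit `rng` argument, and the equal prior model probabilities are assumptions made for illustration.

```python
import random


def prior_m1(rng, dim=2, mu=0.0, sigma=1.0):
    """Draw `dim` location parameters from the M1 prior (centered at 0)."""
    return [rng.gauss(mu, sigma) for _ in range(dim)]


def prior_m2(rng, dim=2, mu=2.0, sigma=1.0):
    """Draw `dim` location parameters from the M2 prior (centered at 2)."""
    return [rng.gauss(mu, sigma) for _ in range(dim)]


def simulator(rng, theta, n_obs=50, scale=1.0):
    """Simulate n_obs observations, one column per entry of theta."""
    return [[rng.gauss(t, scale) for t in theta] for _ in range(n_obs)]


def sample_from_meta_model(rng, priors, model_probs=None):
    """Hypothetical helper: choose a model, then run prior and simulator."""
    if model_probs is None:
        # Assumption: equal prior model probabilities.
        model_probs = [1.0 / len(priors)] * len(priors)
    idx = rng.choices(range(len(priors)), weights=model_probs, k=1)[0]
    theta = priors[idx](rng)
    data = simulator(rng, theta)
    # One-hot encoding of the data-generating model.
    one_hot = [1 if k == idx else 0 for k in range(len(priors))]
    return one_hot, data


rng = random.Random(42)
draws = [sample_from_meta_model(rng, [prior_m1, prior_m2]) for _ in range(1000)]
# Number of simulated data sets per model.
counts = [sum(one_hot[k] for one_hot, _ in draws) for k in range(2)]
```

Each draw pairs a one-hot model indicator with a simulated data set, which is the kind of labeled training signal a classifier over models needs.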
+
+Next, we construct our neural network with a `PMPNetwork` for approximating posterior model probabilities:
+
+```python
+summary_net = bf.networks.DeepSet()
+probability_net = bf.networks.PMPNetwork(num_models=2)
+amortizer = bf.amortizers.AmortizedModelComparison(probability_net, summary_net)
+```
+
+We combine all previous steps with a `Trainer` instance and train the neural approximator:
+
+```python
+trainer = bf.trainers.Trainer(amortizer=amortizer, generative_model=meta_model)
+losses = trainer.train_online(epochs=3, iterations_per_epoch=100, batch_size=32)
+```
+
+Let's simulate data sets from our models to check our networks' performance:
+
+```python
+sims = trainer.configurator(meta_model(5000))
+```
+
+When feeding the data to our trained network, we almost immediately obtain posterior model probabilities for each of the 5000 data sets:
+
+```python
+model_probs = amortizer.posterior_probs(sims)
+```
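
The README text mentions Bayes factors alongside posterior model probabilities. As a hedged sketch of how the two relate: the Bayes factor $BF_{12}$ is the posterior odds divided by the prior odds, so with equal prior model probabilities it reduces to the ratio of the two posterior probabilities. The function `bayes_factor` and the example probabilities below are hypothetical and not part of BayesFlow.

```python
def bayes_factor(post_m1, post_m2, prior_m1=0.5, prior_m2=0.5):
    """Return BF_12 = posterior odds / prior odds for two models."""
    if post_m2 == 0.0:
        # All posterior mass on M1: the evidence ratio is unbounded.
        return float("inf")
    posterior_odds = post_m1 / post_m2
    prior_odds = prior_m1 / prior_m2
    return posterior_odds / prior_odds


# Made-up posterior model probabilities for three data sets.
example_probs = [(0.9, 0.1), (0.5, 0.5), (0.2, 0.8)]
bayes_factors = [bayes_factor(p1, p2) for p1, p2 in example_probs]
```

Values above 1 favor $M_1$, values below 1 favor $M_2$, and a value of 1 means the data do not shift the prior odds.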
+
+How good are these predicted probabilities in the closed world? We can have a look at the calibration:
+
+```python
+cal_curves = bf.diagnostics.plot_calibration_curves(sims["model_indices"], model_probs)
+```
+
+<img src="img/showcase_calibration_curves.png" width=65% height=65%>
+
+Our approximator shows excellent calibration, with the calibration curve being closely aligned to the diagonal, an expected calibration error (ECE) near 0 and most predicted probabilities being certain of the model underlying a data set. We can further assess patterns of misclassification with a confusion matrix:
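
For readers unfamiliar with the expected calibration error, the following is a minimal sketch of the common top-label definition: predictions are binned by confidence, and the ECE is the sample-weighted average gap between accuracy and mean confidence per bin. The function name, the equal-width binning, and the integer-label input are assumptions; the computation inside `plot_calibration_curves` may differ in detail, and one-hot indicators such as `sims["model_indices"]` would first need to be converted to integer labels.

```python
def expected_calibration_error(true_indices, probs, n_bins=10):
    """Top-label ECE with equal-width confidence bins (illustrative)."""
    n = len(probs)
    conf_sums = [0.0] * n_bins
    hit_counts = [0] * n_bins
    bin_counts = [0] * n_bins

    for true_idx, p in zip(true_indices, probs):
        # Predicted model and its confidence.
        pred_idx = max(range(len(p)), key=p.__getitem__)
        conf = p[pred_idx]
        # A confidence of exactly 1.0 falls into the last bin.
        b = min(int(conf * n_bins), n_bins - 1)
        conf_sums[b] += conf
        hit_counts[b] += int(pred_idx == true_idx)
        bin_counts[b] += 1

    ece = 0.0
    for conf_sum, hits, count in zip(conf_sums, hit_counts, bin_counts):
        if count == 0:
            continue  # Empty bins contribute nothing.
        accuracy = hits / count
        mean_conf = conf_sum / count
        ece += (count / n) * abs(accuracy - mean_conf)
    return ece


# Toy case: 75% confidence and 3 of 4 predictions correct.
toy_probs = [[0.75, 0.25]] * 4
toy_truth = [0, 0, 0, 1]
toy_ece = expected_calibration_error(toy_truth, toy_probs)
```

In the toy case the stated confidence matches the observed accuracy, so the gap in the single occupied bin is zero.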
+
+```python
+conf_matrix = bf.diagnostics.plot_confusion_matrix(sims["model_indices"], model_probs)
+```
+
+<img src="img/showcase_confusion_matrix.png" width=44% height=44%>
+
+For the vast majority of simulated data sets, the "true" data-generating model is correctly identified. With these diagnostic results backing us up, we can proceed and apply our trained network to empirical data.
+
+BayesFlow is also able to conduct model comparison for hierarchical models. See this [tutorial notebook](docs/source/tutorial_notebooks/Hierarchical_Model_Comparison_MPT.ipynb) for an introduction to the associated workflow.

 ### References and Further Reading

@@ -190,6 +259,10 @@ doi:10.1109/TNNLS.2021.3124052 available for free at: https://arxiv.org/abs/2004
 Bayesian Model Comparison. <em>ArXiv preprint</em>, available for free at:
 https://arxiv.org/abs/2210.07278

+- Elsemüller, L., Schnuerch, M., Bürkner, P. C., & Radev, S. T. (2023). A Deep
+Learning Method for Comparing Bayesian Hierarchical Models. <em>ArXiv preprint</em>,
+available for free at: https://arxiv.org/abs/2301.11873
+
 ## Likelihood emulation

 Example coming soon...

bayesflow/diagnostics.py

Lines changed: 22 additions & 7 deletions
@@ -1036,7 +1036,7 @@ def plot_calibration_curves(

     Returns
     -------
-    f : plt.Figure - the figure instance for optional saving
+    fig : plt.Figure - the figure instance for optional saving
     """

     num_models = true_models.shape[-1]
@@ -1051,7 +1051,7 @@ def plot_calibration_curves(
     # Initialize figure
     if fig_size is None:
         fig_size = (int(5 * n_col), int(5 * n_row))
-    f, axarr = plt.subplots(n_row, n_col, figsize=fig_size)
+    fig, axarr = plt.subplots(n_row, n_col, figsize=fig_size)
     if n_row > 1:
         ax = axarr.flat

@@ -1096,8 +1096,8 @@ def plot_calibration_curves(

         # Set title
         ax[j].set_title(model_names[j], fontsize=title_fontsize)
-    f.tight_layout()
-    return f
+    fig.tight_layout()
+    return fig


 def plot_confusion_matrix(
@@ -1107,6 +1107,8 @@ def plot_confusion_matrix(
     fig_size=(5, 5),
     title_fontsize=18,
     tick_fontsize=12,
+    xtick_rotation=None,
+    ytick_rotation=None,
     normalize=True,
     cmap=None,
     title=True,
@@ -1124,9 +1126,13 @@ def plot_confusion_matrix(
     fig_size : tuple or None, optional, default: (5, 5)
         The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
     title_fontsize : int, optional, default: 18
-        The font size of the axis label texts.
+        The font size of the title text.
     tick_fontsize : int, optional, default: 12
-        The font size of the axis label texts.
+        The font size of the axis label and model name texts.
+    xtick_rotation: int, optional, default: None
+        Rotation of x-axis tick labels (helps with long model names).
+    ytick_rotation: int, optional, default: None
+        Rotation of y-axis tick labels (helps with long model names).
     normalize : bool, optional, default: True
         A flag for normalization of the confusion matrix.
         If True, each row of the confusion matrix is normalized to sum to 1.
@@ -1135,6 +1141,10 @@ def plot_confusion_matrix(
         e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red.
     title : bool, optional, default True
         A flag for adding 'Confusion Matrix' above the matrix.
+
+    Returns
+    -------
+    fig : plt.Figure - the figure instance for optional saving
     """

     if model_names is None:
@@ -1154,14 +1164,18 @@ def plot_confusion_matrix(
         cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

     # Initialize figure
-    f, ax = plt.subplots(1, 1, figsize=fig_size)
+    fig, ax = plt.subplots(1, 1, figsize=fig_size)

     im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
     ax.figure.colorbar(im, ax=ax, shrink=0.7)

     ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]))
     ax.set_xticklabels(model_names, fontsize=tick_fontsize)
+    if xtick_rotation:
+        plt.xticks(rotation=xtick_rotation, ha="right")
     ax.set_yticklabels(model_names, fontsize=tick_fontsize)
+    if ytick_rotation:
+        plt.yticks(rotation=ytick_rotation)
     ax.set_xlabel("Predicted model", fontsize=tick_fontsize)
     ax.set_ylabel("True model", fontsize=tick_fontsize)

@@ -1175,6 +1189,7 @@ def plot_confusion_matrix(
             )
     if title:
         ax.set_title("Confusion Matrix", fontsize=title_fontsize)
+    return fig


 def plot_mmd_hypothesis_test(mmd_null,
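
The `normalize` flag documented in `plot_confusion_matrix` divides each row of the confusion matrix by its row sum, so that each row shows how the data sets of one true model are distributed over the predicted models. A dependency-free sketch of that step follows. The function name and the integer-label inputs are assumptions for illustration, and unlike the array division in the hunk above, this sketch leaves rows without any samples as zeros.

```python
def row_normalized_confusion_matrix(true_indices, pred_indices, num_models):
    """Count (true, predicted) pairs and normalize each row to sum to 1."""
    cm = [[0.0] * num_models for _ in range(num_models)]
    for t, p in zip(true_indices, pred_indices):
        cm[t][p] += 1.0

    for row in cm:
        total = sum(row)
        if total > 0:
            # Divide by the number of data sets generated by this true model.
            row[:] = [value / total for value in row]
    return cm


# Toy labels: three data sets from model 0, one from model 1.
toy_true = [0, 0, 0, 1]
toy_pred = [0, 0, 1, 1]
toy_cm = row_normalized_confusion_matrix(toy_true, toy_pred, num_models=2)
```

After normalization the diagonal entries can be read as per-model recovery rates, which is why the README interprets the plot in terms of how often the "true" model is identified.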
