
Commit a0dc5de

Merge pull request #71 from elseml/Development
Add model comparison tutorial notebooks
2 parents 168d961 + 4473772 commit a0dc5de

7 files changed (+2160, -3 lines)

README.md

Lines changed: 76 additions & 1 deletion
@@ -11,6 +11,8 @@ For starters, check out some of our walk-through notebooks:
2. [Principled Bayesian workflow for cognitive models](docs/source/tutorial_notebooks/LCA_Model_Posterior_Estimation.ipynb)
3. [Posterior estimation for ODEs](docs/source/tutorial_notebooks/Linear_ODE_system.ipynb)
4. [Posterior estimation for SIR-like models](docs/source/tutorial_notebooks/Covid19_Initial_Posterior_Estimation.ipynb)
5. [Model comparison for cognitive models](docs/source/tutorial_notebooks/Model_Comparison_MPT.ipynb)
6. [Hierarchical model comparison for cognitive models](docs/source/tutorial_notebooks/Hierarchical_Model_Comparison_MPT.ipynb)

## Project Documentation

@@ -177,7 +179,76 @@ preprint</em>, available for free at: https://arxiv.org/abs/2112.08866
## Model Comparison

~~Example coming soon...~~

BayesFlow can be used not only for parameter estimation, but also for approximate Bayesian model comparison via posterior model probabilities or Bayes factors.
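
For two competing models, posterior model probabilities and Bayes factors are linked by the standard posterior-odds identity:

$$
\frac{p(M_1 \mid x)}{p(M_2 \mid x)} = \mathrm{BF}_{12} \cdot \frac{p(M_1)}{p(M_2)}, \qquad \mathrm{BF}_{12} = \frac{p(x \mid M_1)}{p(x \mid M_2)},
$$

so an approximator of posterior model probabilities also yields Bayes factors whenever the prior model probabilities are known.
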
Let's extend the minimal example from before with a second model $M_2$ that we want to compare with our original model $M_1$:
```python
def simulator(theta, n_obs=50, scale=1.0):
    return np.random.default_rng().normal(loc=theta, scale=scale, size=(n_obs, theta.shape[0]))

def prior_m1(D=2, mu=0., sigma=1.0):
    return np.random.default_rng().normal(loc=mu, scale=sigma, size=D)

def prior_m2(D=2, mu=2., sigma=1.0):
    return np.random.default_rng().normal(loc=mu, scale=sigma, size=D)
```

We create both models as before and use a `MultiGenerativeModel` wrapper to combine them in a `meta_model`:
```python
model_m1 = bf.simulation.GenerativeModel(prior_m1, simulator, simulator_is_batched=False)
model_m2 = bf.simulation.GenerativeModel(prior_m2, simulator, simulator_is_batched=False)
meta_model = bf.simulation.MultiGenerativeModel([model_m1, model_m2])
```

Next, we construct our neural network with a `PMPNetwork` for approximating posterior model probabilities:
```python
summary_net = bf.networks.DeepSet()
probability_net = bf.networks.PMPNetwork(num_models=2)
amortizer = bf.amortizers.AmortizedModelComparison(probability_net, summary_net)
```

We combine all previous steps with a `Trainer` instance and train the neural approximator:
```python
trainer = bf.trainers.Trainer(amortizer=amortizer, generative_model=meta_model)
losses = trainer.train_online(epochs=3, iterations_per_epoch=100, batch_size=32)
```
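
If desired, the returned loss history can be inspected for convergence. This is an optional step; it assumes that `plot_losses` is available in `bf.diagnostics` in your BayesFlow version and accepts the returned loss history directly:

```python
# Optional: visualize the training loss trajectory
fig = bf.diagnostics.plot_losses(losses)
```
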
Let's simulate data sets from our models to check our networks' performance:
```python
sim_data = trainer.configurator(meta_model(5000))
sim_indices = sim_data["model_indices"]
```

When feeding the data to our trained network, we almost immediately obtain posterior model probabilities for each of the 5000 data sets:
```python
sim_preds = amortizer(sim_data)
```
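
Because both models enter the simulation with equal prior probability (assuming the default uniform model prior of `MultiGenerativeModel`), pairwise Bayes factors can be read off as ratios of the predicted probabilities. A minimal sketch, assuming `sim_preds` can be converted to a NumPy array of shape `(num_data_sets, num_models)`:

```python
import numpy as np

# Posterior model probabilities per data set, shape (num_data_sets, num_models)
probs = np.asarray(sim_preds)

# Bayes factor of M1 over M2 per data set; the ratio of posterior probabilities
# equals the Bayes factor only because the prior model probabilities are equal
bf_12 = probs[:, 0] / np.clip(probs[:, 1], 1e-12, None)
```
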
How good are these predicted probabilities? We can have a look at the calibration:
```python
cal_curves = bf.diagnostics.plot_calibration_curves(sim_indices, sim_preds)
```

<img src="img/showcase_calibration_curves.png" width=65% height=65%>

Our approximator shows excellent calibration: the calibration curve closely follows the diagonal, the expected calibration error (ECE) is near 0, and most predicted probabilities are highly confident about which model underlies a given data set. We can further assess patterns of misclassification with a confusion matrix:
```python
conf_matrix = bf.diagnostics.plot_confusion_matrix(sim_indices, sim_preds)
```

<img src="img/showcase_confusion_matrix.png" width=44% height=44%>

For the vast majority of simulated data sets, the generating model is correctly detected. With these diagnostic results backing us up, we can safely apply our trained network to empirical data.
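
A hypothetical sketch of that final step: `emp_data` is a made-up array of empirical observations arranged like a single simulated data set above, and we assume the amortizer accepts a configured dictionary with the key `"summary_conditions"`, as in recent BayesFlow versions; the exact input format depends on your configurator:

```python
# emp_data: hypothetical empirical observations with shape (1, n_obs, D),
# formatted like one simulated data set from above
emp_input = {"summary_conditions": emp_data.astype(np.float32)}

# Approximate posterior model probabilities for the empirical data set
emp_probs = amortizer(emp_input)
```
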
BayesFlow is also able to conduct model comparison for hierarchical models. See this [tutorial notebook](docs/source/tutorial_notebooks/Hierarchical_Model_Comparison_MPT.ipynb) for an introduction to the associated workflow.

### References and Further Reading

@@ -190,6 +261,10 @@ doi:10.1109/TNNLS.2021.3124052 available for free at: https://arxiv.org/abs/2004
Bayesian Model Comparison. <em>ArXiv preprint</em>, available for free at:
https://arxiv.org/abs/2210.07278

- Elsemüller, L., Schnuerch, M., Bürkner, P. C., & Radev, S. T. (2023). A Deep
  Learning Method for Comparing Bayesian Hierarchical Models. <em>ArXiv preprint</em>,
  available for free at: https://arxiv.org/abs/2301.11873

## Likelihood emulation

Example coming soon...

bayesflow/diagnostics.py

Lines changed: 12 additions & 2 deletions
@@ -1107,6 +1107,8 @@ def plot_confusion_matrix(
     fig_size=(5, 5),
     title_fontsize=18,
     tick_fontsize=12,
+    xtick_rotation=None,
+    ytick_rotation=None,
     normalize=True,
     cmap=None,
     title=True,
@@ -1124,9 +1126,13 @@ def plot_confusion_matrix(
     fig_size : tuple or None, optional, default: (5, 5)
         The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
     title_fontsize : int, optional, default: 18
-        The font size of the axis label texts.
+        The font size of the title text.
     tick_fontsize : int, optional, default: 12
-        The font size of the axis label texts.
+        The font size of the axis label and model name texts.
+    xtick_rotation: int, optional, default: None
+        Rotation of x-axis tick labels (helps with long model names).
+    ytick_rotation: int, optional, default: None
+        Rotation of y-axis tick labels (helps with long model names).
     normalize : bool, optional, default: True
         A flag for normalization of the confusion matrix.
         If True, each row of the confusion matrix is normalized to sum to 1.
@@ -1165,7 +1171,11 @@ def plot_confusion_matrix(

     ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]))
     ax.set_xticklabels(model_names, fontsize=tick_fontsize)
+    if xtick_rotation:
+        plt.xticks(rotation=xtick_rotation, ha="right")
     ax.set_yticklabels(model_names, fontsize=tick_fontsize)
+    if ytick_rotation:
+        plt.yticks(rotation=ytick_rotation)
     ax.set_xlabel("Predicted model", fontsize=tick_fontsize)
     ax.set_ylabel("True model", fontsize=tick_fontsize)
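
A small usage sketch of the new rotation arguments (the model names are made up, and we assume `plot_confusion_matrix` accepts `model_names` as an optional keyword, as the diff suggests):

```python
# Hypothetical call with longer model names that benefit from rotated tick labels
conf_matrix = bf.diagnostics.plot_confusion_matrix(
    sim_indices,
    sim_preds,
    model_names=["Hierarchical MPT", "Non-hierarchical MPT"],
    xtick_rotation=45,
    ytick_rotation=45,
)
```
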

docs/source/tutorial_notebooks/Hierarchical_Model_Comparison_MPT.ipynb

Lines changed: 958 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/tutorial_notebooks/Model_Comparison_MPT.ipynb

Lines changed: 1114 additions & 0 deletions
Large diffs are not rendered by default.
(image file, 111 KB)

(image file, 181 KB)

img/showcase_confusion_matrix.png

71.3 KB
