Commit fda4479

Merge pull request #103 from stefanradev93/Development
Development
2 parents 2bd8744 + 7214f78 commit fda4479

File tree

5 files changed: +121, -41 lines changed

CITATION.cff

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+cff-version: "1.2.0"
+authors:
+- family-names: Radev
+  given-names: Stefan T.
+  orcid: "https://orcid.org/0000-0002-6702-9559"
+- family-names: Schmitt
+  given-names: Marvin
+  orcid: "https://orcid.org/0000-0003-1293-820X"
+- family-names: Schumacher
+  given-names: Lukas
+  orcid: "https://orcid.org/0000-0003-1512-8288"
+- family-names: Elsemüller
+  given-names: Lasse
+  orcid: "https://orcid.org/0000-0003-0368-720X"
+- family-names: Pratz
+  given-names: Valentin
+  orcid: "https://orcid.org/0000-0001-8371-3417"
+- family-names: Schälte
+  given-names: Yannik
+  orcid: "https://orcid.org/0000-0003-1293-820X"
+- family-names: Köthe
+  given-names: Ullrich
+  orcid: "https://orcid.org/0000-0001-6036-1287"
+- family-names: Bürkner
+  given-names: Paul-Christian
+  orcid: "https://orcid.org/0000-0001-5765-8995"
+contact:
+- family-names: Radev
+  given-names: Stefan T.
+  orcid: "https://orcid.org/0000-0002-6702-9559"
+doi: 10.5281/zenodo.8346393
+message: If you use this software, please cite our article in the
+  Journal of Open Source Software.
+preferred-citation:
+  authors:
+  - family-names: Radev
+    given-names: Stefan T.
+    orcid: "https://orcid.org/0000-0002-6702-9559"
+  - family-names: Schmitt
+    given-names: Marvin
+    orcid: "https://orcid.org/0000-0003-1293-820X"
+  - family-names: Schumacher
+    given-names: Lukas
+    orcid: "https://orcid.org/0000-0003-1512-8288"
+  - family-names: Elsemüller
+    given-names: Lasse
+    orcid: "https://orcid.org/0000-0003-0368-720X"
+  - family-names: Pratz
+    given-names: Valentin
+    orcid: "https://orcid.org/0000-0001-8371-3417"
+  - family-names: Schälte
+    given-names: Yannik
+    orcid: "https://orcid.org/0000-0003-1293-820X"
+  - family-names: Köthe
+    given-names: Ullrich
+    orcid: "https://orcid.org/0000-0001-6036-1287"
+  - family-names: Bürkner
+    given-names: Paul-Christian
+    orcid: "https://orcid.org/0000-0001-5765-8995"
+  date-published: 2023-09-22
+  doi: 10.21105/joss.05702
+  issn: 2475-9066
+  issue: 89
+  journal: Journal of Open Source Software
+  publisher:
+    name: Open Journals
+  start: 5702
+  title: "BayesFlow: Amortized Bayesian Workflows With Neural Networks"
+  type: article
+  url: "https://joss.theoj.org/papers/10.21105/joss.05702"
+  volume: 8
+title: "BayesFlow: Amortized Bayesian Workflows With Neural Networks"
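Because CITATION.cff is machine-readable, the preferred citation can be assembled programmatically. A minimal sketch (not part of the commit), assuming PyYAML is installed:

```python
# Sketch: build the preferred citation string from the new CITATION.cff.
import yaml

with open("CITATION.cff") as f:
    cff = yaml.safe_load(f)

pref = cff["preferred-citation"]
authors = ", ".join(f"{a['given-names']} {a['family-names']}" for a in pref["authors"])
year = str(pref["date-published"])[:4]  # YAML parses the date; stringify before slicing
print(f"{authors} ({year}). {pref['title']}. {pref['journal']}, "
      f"{pref['volume']}({pref['issue']}), {pref['start']}. doi:{pref['doi']}")
```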

README.md

Lines changed: 21 additions & 19 deletions

@@ -2,6 +2,7 @@
 
 [![Actions Status](https://github.com/stefanradev93/bayesflow/workflows/Tests/badge.svg)](https://github.com/stefanradev93/bayesflow/actions)
 [![Licence](https://img.shields.io/github/license/stefanradev93/BayesFlow)](https://img.shields.io/github/license/stefanradev93/BayesFlow)
+[![DOI](https://joss.theoj.org/papers/10.21105/joss.05702/status.svg)](https://doi.org/10.21105/joss.05702)
 
 Welcome to our BayesFlow library for efficient simulation-based Bayesian workflows! Our library enables users to create specialized neural networks for *amortized Bayesian inference*, which repay users with rapid statistical inference after a potentially longer simulation-based training phase.
 
@@ -76,51 +77,52 @@ generative_model = bf.simulation.GenerativeModel(prior, simulator)
 Next, we create our BayesFlow setup consisting of a summary and an inference network:
 
 ```python
-summary_net = bf.networks.DeepSet()
+summary_net = bf.networks.SetTransformer(input_dim=2)
 inference_net = bf.networks.InvertibleNetwork(num_params=2)
 amortized_posterior = bf.amortizers.AmortizedPosterior(inference_net, summary_net)
 ```
 
 Finally, we connect the networks with the generative model via a `Trainer` instance:
 
 ```python
-trainer = bf.trainers.Trainer(amortizer=amortized_posterior, generative_model=generative_model, memory=True)
+trainer = bf.trainers.Trainer(amortizer=amortized_posterior, generative_model=generative_model)
 ```
 
 We are now ready to train an amortized posterior approximator. For instance,
 to run online training, we simply call:
 
 ```python
-losses = trainer.train_online(epochs=10, iterations_per_epoch=500, batch_size=32)
+losses = trainer.train_online(epochs=10, iterations_per_epoch=1000, batch_size=32)
 ```
 
-Before inference, we can use simulation-based calibration (SBC,
+Prior to inference, we can use simulation-based calibration (SBC,
 https://arxiv.org/abs/1804.06788) to check the computational faithfulness of
-the model-amortizer combination:
+the model-amortizer combination on unseen simulations:
 
 ```python
-fig = trainer.diagnose_sbc_histograms()
+# Generate 500 new simulated data sets
+new_sims = trainer.configurator(generative_model(500))
+
+# Obtain 100 posterior draws per data set instantly
+posterior_draws = amortized_posterior.sample(new_sims, n_samples=100)
+
+# Diagnose calibration
+fig = bf.diagnostics.plot_sbc_histograms(posterior_draws, new_sims['parameters'])
 ```
 
 <img src="https://github.com/stefanradev93/BayesFlow/blob/master/img/showcase_sbc.png?raw=true" width=65% height=65%>
 
 The histograms are roughly uniform and lie within the expected range for
 well-calibrated inference algorithms as indicated by the shaded gray areas.
-Accordingly, our amortizer seems to have converged to the intended target.
-
-Amortized inference on new (real or simulated) data is then easy and fast.
-For example, we can simulate 200 new data sets and generate 500 posterior draws
-per data set:
+Accordingly, our neural approximator seems to have converged to the intended target.
 
-```python
-new_sims = trainer.configurator(generative_model(200))
-posterior_draws = amortized_posterior.sample(new_sims, n_samples=500)
-```
+As you can see, amortized inference on new (real or simulated) data is easy and fast.
+We can obtain a further 5000 posterior draws per simulated data set and quickly inspect
+how well the model can recover its parameters across the entire *prior predictive distribution*.
 
-We can then quickly inspect the how well the model can recover its parameters
-across the simulated data sets.
 
 ```python
+posterior_draws = amortized_posterior.sample(new_sims, n_samples=5000)
 fig = bf.diagnostics.plot_recovery(posterior_draws, new_sims['parameters'])
 ```
 
@@ -162,7 +164,7 @@ A modified loss function optimizes the learned summary statistics towards a unit
 Gaussian and reliably detects model misspecification during inference time.
 
 
-<img src="https://github.com/stefanradev93/BayesFlow/blob/master/examples/img/model_misspecification_amortized_sbi.png" width=100% height=100%>
+<img src="https://github.com/stefanradev93/BayesFlow/blob/master/examples/img/model_misspecification_amortized_sbi.png?raw=true" width=100% height=100%>
 
 In order to use this method, you should only provide the `summary_loss_fun` argument
 to the `AmortizedPosterior` instance:
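The code block that the last context line introduces falls outside the diff; judging from the surrounding text, it presumably resembles the sketch below (treating the `"MMD"` shorthand for a maximum mean discrepancy summary loss as an assumption):

```python
# Sketch: enable the misspecification-detecting summary loss
amortized_posterior = bf.amortizers.AmortizedPosterior(
    inference_net, summary_net, summary_loss_fun="MMD"  # assumed shorthand
)
```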
@@ -206,7 +208,7 @@ meta_model = bf.simulation.MultiGenerativeModel([model_m1, model_m2])
 Next, we construct our neural network with a `PMPNetwork` for approximating posterior model probabilities:
 
 ```python
-summary_net = bf.networks.DeepSet()
+summary_net = bf.networks.SetTransformer(input_dim=2)
 probability_net = bf.networks.PMPNetwork(num_models=2)
 amortized_bmc = bf.amortizers.AmortizedModelComparison(probability_net, summary_net)
 ```
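A hedged sketch of how this model-comparison amortizer would be trained and queried, mirroring the posterior workflow above; the `posterior_probs` convenience call is an assumption, not shown in this diff:

```python
trainer = bf.trainers.Trainer(amortizer=amortized_bmc, generative_model=meta_model)
losses = trainer.train_online(epochs=10, iterations_per_epoch=1000, batch_size=32)

# Approximate posterior model probabilities for new simulations
new_sims = trainer.configurator(meta_model(500))
model_probs = amortized_bmc.posterior_probs(new_sims)  # assumed convenience method
```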

bayesflow/default_settings.py

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ def __init__(self, meta_dict: dict, mandatory_fields: list = []):
 }
 
 
-DEFAULT_SETTING_DENSE_INVARIANT = {"units": 64, "activation": "relu", "kernel_initializer": "glorot_uniform"}
+DEFAULT_SETTING_DENSE_DEEP_SET = {"units": 64, "activation": "relu", "kernel_initializer": "glorot_uniform"}
 
 
 DEFAULT_SETTING_DENSE_RECT = {"units": 256, "activation": "swish", "kernel_initializer": "glorot_uniform"}
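Since the renamed dictionary remains only the default for `DeepSet`'s dense blocks, callers can still override it per instance via the `dense_s*_args` arguments seen in the summary-network diff below. A sketch with illustrative values:

```python
import bayesflow as bf

# Wider first-stage dense layers; the other stages keep the renamed defaults
summary_net = bf.networks.DeepSet(
    dense_s1_args={"units": 128, "activation": "relu", "kernel_initializer": "glorot_uniform"}
)
```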

bayesflow/diagnostics.py

Lines changed: 10 additions & 4 deletions

@@ -51,6 +51,8 @@ def plot_recovery(
     color="#8f2727",
     n_col=None,
     n_row=None,
+    xlabel="Ground truth",
+    ylabel="Estimated",
 ):
     """Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty.
     The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate
@@ -96,7 +98,11 @@ def plot_recovery(
         A flag for adding R^2 between true and estimates to the plot
     color : str, optional, default: '#8f2727'
         The color for the true vs. estimated scatter points and error bars
-
+    xlabel : str, optional, default: 'Ground truth'
+        The label on the x-axis of the plot
+    ylabel : str, optional, default: 'Estimated'
+        The label on the y-axis of the plot
+
     Returns
     -------
     f : plt.Figure - the figure instance for optional saving
@@ -198,15 +204,15 @@ def plot_recovery(
     # Only add x-labels to the bottom row
     bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :]
     for _ax in bottom_row:
-        _ax.set_xlabel("Ground truth", fontsize=label_fontsize)
+        _ax.set_xlabel(xlabel, fontsize=label_fontsize)
 
     # Only add y-labels to right left-most row
     if n_row == 1:  # if there is only one row, the ax array is 1D
-        axarr[0].set_ylabel("Estimated", fontsize=label_fontsize)
+        axarr[0].set_ylabel(ylabel, fontsize=label_fontsize)
     # If there is more than one row, the ax array is 2D
     else:
         for _ax in axarr[:, 0]:
-            _ax.set_ylabel("Estimated", fontsize=label_fontsize)
+            _ax.set_ylabel(ylabel, fontsize=label_fontsize)
 
     # Remove unused axes entirely
     for _ax in axarr_it[n_params:]:
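Usage sketch for the new keyword arguments (the label strings are illustrative; `posterior_draws` and `new_sims` as in the README example):

```python
fig = bf.diagnostics.plot_recovery(
    posterior_draws,
    new_sims["parameters"],
    xlabel="True parameter value",
    ylabel="Posterior point estimate",
)
```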

bayesflow/summary_networks.py

Lines changed: 17 additions & 17 deletions

@@ -198,8 +198,8 @@ def __init__(
         features from an input set using a set of seed vectors (typically one for a single summary) with ``summary_dim``
         output dimensions.
 
-        Recommnded: When using transformers as summary networks, you may want to use a smaller learning rate
-        during training, e.g., setting ``default_lr=1e-5`` in a ``Trainer`` instance.
+        Recommended: When using transformers as summary networks, you may want to use a smaller learning rate
+        during training, e.g., setting ``default_lr=1e-4`` in a ``Trainer`` instance.
 
         Parameters
         ----------
@@ -211,7 +211,7 @@ def __init__(
 
             ``attention_settings=dict(num_heads=4, key_dim=32)``
 
-            You may also want to include dropout regularization in small-to-medium data regimes:
+            You may also want to include stronger dropout regularization in small-to-medium data regimes:
 
             ``attention_settings=dict(num_heads=4, key_dim=32, dropout=0.1)``
 
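Putting the revised recommendation together, a sketch of a transformer-based setup with attention dropout and the smaller learning rate (the concrete values are illustrative; `default_lr` and `attention_settings` are named in the docstring above):

```python
import bayesflow as bf

summary_net = bf.networks.SetTransformer(
    input_dim=2,
    attention_settings=dict(num_heads=4, key_dim=32, dropout=0.1),
)
inference_net = bf.networks.InvertibleNetwork(num_params=2)
amortized_posterior = bf.amortizers.AmortizedPosterior(inference_net, summary_net)
trainer = bf.trainers.Trainer(
    amortizer=amortized_posterior,
    generative_model=generative_model,  # as defined in the README example
    default_lr=1e-4,  # the smaller learning rate recommended in the docstring
)
```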
@@ -235,7 +235,7 @@ def __init__(
         The number of self-attention blocks to use before pooling.
     num_inducing_points : int or None, optional, default: 32
         The number of inducing points. Should be lower than the smallest set size.
-        If ``None`` selected, a vanilla self-attenion block (SAB) will be used, otherwise
+        If ``None`` selected, a vanilla self-attention block (SAB) will be used, otherwise
         ISAB blocks will be used. For ``num_attention_blocks > 1``, we currently recommend
         always using some number of inducing points.
     num_seeds : int, optional, default: 1
@@ -355,9 +355,9 @@ def __init__(
             num_dense_s1=num_dense_s1,
             num_dense_s2=num_dense_s2,
             num_dense_s3=num_dense_s3,
-            dense_s1_args=defaults.DEFAULT_SETTING_DENSE_INVARIANT if dense_s1_args is None else dense_s1_args,
-            dense_s2_args=defaults.DEFAULT_SETTING_DENSE_INVARIANT if dense_s2_args is None else dense_s2_args,
-            dense_s3_args=defaults.DEFAULT_SETTING_DENSE_INVARIANT if dense_s3_args is None else dense_s3_args,
+            dense_s1_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s1_args is None else dense_s1_args,
+            dense_s2_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s2_args is None else dense_s2_args,
+            dense_s3_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s3_args is None else dense_s3_args,
             pooling_fun=pooling_fun,
         )
 
@@ -369,7 +369,7 @@ def __init__(
         self.out_layer = Dense(summary_dim, activation="linear")
         self.summary_dim = summary_dim
 
-    def call(self, x):
+    def call(self, x, **kwargs):
         """Performs the forward pass of a learnable deep invariant transformation consisting of
         a sequence of equivariant transforms followed by an invariant transform.
 
@@ -385,10 +385,10 @@ def call(self, x):
         """
 
         # Pass through series of augmented equivariant transforms
-        out_equiv = self.equiv_layers(x)
+        out_equiv = self.equiv_layers(x, **kwargs)
 
         # Pass through final invariant layer
-        out = self.out_layer(self.inv(out_equiv))
+        out = self.out_layer(self.inv(out_equiv, **kwargs), **kwargs)
 
         return out
 
@@ -443,7 +443,7 @@ def __init__(
     conv_settings : dict or None, optional, default: None
         The arguments passed to the `MultiConv1D` internal networks. If `None`,
         defaults will be used from `default_settings`. If a dictionary is provided,
-        it should contain the followin keys:
+        it should contain the following keys:
         - layer_args (dict) : arguments for `tf.keras.layers.Conv1D` without kernel_size
         - min_kernel_size (int) : the minimum kernel size (>= 1)
         - max_kernel_size (int) : the maximum kernel size
@@ -508,8 +508,8 @@ class SplitNetwork(tf.keras.Model):
     of data to provide an individual network for each split of the data.
     """
 
-    def __init__(self, num_splits, split_data_configurator, network_type=InvariantNetwork, network_kwargs={}, **kwargs):
-        """Creates a composite network of `num_splits` sub-networks of type `network_type`, each with configuration
+    def __init__(self, num_splits, split_data_configurator, network_type=DeepSet, network_kwargs={}, **kwargs):
+        """Creates a composite network of `num_splits` subnetworks of type `network_type`, each with configuration
         specified by `meta`.
 
         Parameters
@@ -535,7 +535,7 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe
            indicating which rows belong to the split `i`.
        network_type : callable, optional, default: `InvariantNetowk`
            Type of neural network to use.
-       meta : dict, optional, default: {}
+       network_kwargs : dict, optional, default: {}
            A dictionary containing the configuration for the networks.
        **kwargs
            Optional keyword arguments to be passed to the `tf.keras.Model` superclass.
@@ -547,7 +547,7 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe
        self.split_data_configurator = split_data_configurator
        self.networks = [network_type(**network_kwargs) for _ in range(num_splits)]
 
-    def call(self, x):
+    def call(self, x, **kwargs):
        """Performs a forward pass through the subnetworks and concatenates their output.
 
        Parameters
@@ -561,7 +561,7 @@ def call(self, x):
            Output of shape (batch_size, out_dim)
        """
 
-       out = [self.networks[i](self.split_data_configurator(i, x)) for i in range(self.num_splits)]
+       out = [self.networks[i](self.split_data_configurator(i, x), **kwargs) for i in range(self.num_splits)]
        out = tf.concat(out, axis=-1)
        return out
 
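A hypothetical `SplitNetwork` sketch consistent with the new `DeepSet` default and the `split_data_configurator(i, x)` call signature used above (the configurator body is illustrative, not from the diff):

```python
import bayesflow as bf

def split_halves(i, x):
    # Route the first or second half of the feature axis to subnetwork i
    half = x.shape[-1] // 2
    return x[..., i * half:(i + 1) * half]

summary_net = bf.networks.SplitNetwork(
    num_splits=2,
    split_data_configurator=split_halves,
    network_type=bf.networks.DeepSet,  # the new default, stated explicitly
)
```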

@@ -602,7 +602,7 @@ def call(self, x, return_all=False, **kwargs):
 
         Parameters
         ----------
-        data : tf.Tensor of shape (batch_size, ..., data_dim)
+        x : tf.Tensor of shape (batch_size, ..., data_dim)
             Example, hierarchical data sets with two levels:
             (batch_size, D, L, x_dim) -> reduces to (batch_size, out_dim).
         return_all : boolean, optional, default: False
0 commit comments