
Commit 70cc501: Small changes to doc and **kwargs
1 parent: 45da505

3 files changed, 20 insertions(+), 20 deletions(-)

README.md

Lines changed: 2 additions & 2 deletions
@@ -103,7 +103,7 @@ the model-amortizer combination on unseen simulations:
 new_sims = trainer.configurator(generative_model(500))
 
 # Obtain 100 posteriors draws per data set instantly
-posterior_draws = amortized_posterior.sample(new_sims, n_samples=250)
+posterior_draws = amortized_posterior.sample(new_sims, n_samples=100)
 
 # Diagnoze calibration
 fig = bf.diagnostics.plot_sbc_histograms(posterior_draws, new_sims['parameters'])

@@ -207,7 +207,7 @@ meta_model = bf.simulation.MultiGenerativeModel([model_m1, model_m2])
 Next, we construct our neural network with a `PMPNetwork` for approximating posterior model probabilities:
 
 ```python
-summary_net = bf.networks.DeepSet()
+summary_net = bf.networks.SetTransformer(input_dim=2)
 probability_net = bf.networks.PMPNetwork(num_models=2)
 amortized_bmc = bf.amortizers.AmortizedModelComparison(probability_net, summary_net)
 ```
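The swap from `DeepSet` to `SetTransformer(input_dim=2)` pairs with the learning-rate note updated in `summary_networks.py` below. As a hedged sketch (not part of the commit), this is how the two changes could be combined; it reuses `meta_model` from the README and assumes the usual `bf.trainers.Trainer` keyword arguments (`amortizer`, `generative_model`, `default_lr`):

```python
import bayesflow as bf

# Summary and probability networks exactly as in the updated README snippet
summary_net = bf.networks.SetTransformer(input_dim=2)
probability_net = bf.networks.PMPNetwork(num_models=2)
amortized_bmc = bf.amortizers.AmortizedModelComparison(probability_net, summary_net)

# Assumed Trainer signature; `meta_model` is the MultiGenerativeModel defined
# earlier in the README. `default_lr=1e-4` follows the recommendation added to
# the SetTransformer docstring in this commit.
trainer = bf.trainers.Trainer(
    amortizer=amortized_bmc,
    generative_model=meta_model,
    default_lr=1e-4,
)
```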

bayesflow/default_settings.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def __init__(self, meta_dict: dict, mandatory_fields: list = []):
 }
 
 
-DEFAULT_SETTING_DENSE_INVARIANT = {"units": 64, "activation": "relu", "kernel_initializer": "glorot_uniform"}
+DEFAULT_SETTING_DENSE_DEEP_SET = {"units": 64, "activation": "relu", "kernel_initializer": "glorot_uniform"}
 
 
 DEFAULT_SETTING_DENSE_RECT = {"units": 256, "activation": "swish", "kernel_initializer": "glorot_uniform"}
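The rename from `DEFAULT_SETTING_DENSE_INVARIANT` to `DEFAULT_SETTING_DENSE_DEEP_SET` matches the `DeepSet` summary network that consumes it (see `summary_networks.py` below); the dict contents are unchanged. For orientation, a minimal sketch (not part of the commit) of how such a settings dict is consumed:

```python
import tensorflow as tf
from bayesflow import default_settings as defaults

# The dict holds plain Dense-layer kwargs
# ({"units": 64, "activation": "relu", "kernel_initializer": "glorot_uniform"}),
# so it can be unpacked straight into a Keras layer.
dense = tf.keras.layers.Dense(**defaults.DEFAULT_SETTING_DENSE_DEEP_SET)
```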

bayesflow/summary_networks.py

Lines changed: 17 additions & 17 deletions
@@ -198,8 +198,8 @@ def __init__(
 features from an input set using a set of seed vectors (typically one for a single summary) with ``summary_dim``
 output dimensions.
 
-Recommnded: When using transformers as summary networks, you may want to use a smaller learning rate
-during training, e.g., setting ``default_lr=1e-5`` in a ``Trainer`` instance.
+Recommended: When using transformers as summary networks, you may want to use a smaller learning rate
+during training, e.g., setting ``default_lr=1e-4`` in a ``Trainer`` instance.
 
 Parameters
 ----------
@@ -211,7 +211,7 @@ def __init__(
 
 ``attention_settings=dict(num_heads=4, key_dim=32)``
 
-You may also want to include dropout regularization in small-to-medium data regimes:
+You may also want to include stronger dropout regularization in small-to-medium data regimes:
 
 ``attention_settings=dict(num_heads=4, key_dim=32, dropout=0.1)``
 
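To make the dropout recommendation concrete, here is a brief sketch (the setting values are the ones quoted in the docstring, not new defaults; `input_dim=2` is borrowed from the README example):

```python
import bayesflow as bf

# Stronger attention dropout, as suggested for small-to-medium data regimes
summary_net = bf.networks.SetTransformer(
    input_dim=2,
    attention_settings=dict(num_heads=4, key_dim=32, dropout=0.1),
)
```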
@@ -235,7 +235,7 @@ def __init__(
 The number of self-attention blocks to use before pooling.
 num_inducing_points : int or None, optional, default: 32
 The number of inducing points. Should be lower than the smallest set size.
-If ``None`` selected, a vanilla self-attenion block (SAB) will be used, otherwise
+If ``None`` selected, a vanilla self-attention block (SAB) will be used, otherwise
 ISAB blocks will be used. For ``num_attention_blocks > 1``, we currently recommend
 always using some number of inducing points.
 num_seeds : int, optional, default: 1
@@ -355,9 +355,9 @@ def __init__(
 num_dense_s1=num_dense_s1,
 num_dense_s2=num_dense_s2,
 num_dense_s3=num_dense_s3,
-dense_s1_args=defaults.DEFAULT_SETTING_DENSE_INVARIANT if dense_s1_args is None else dense_s1_args,
-dense_s2_args=defaults.DEFAULT_SETTING_DENSE_INVARIANT if dense_s2_args is None else dense_s2_args,
-dense_s3_args=defaults.DEFAULT_SETTING_DENSE_INVARIANT if dense_s3_args is None else dense_s3_args,
+dense_s1_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s1_args is None else dense_s1_args,
+dense_s2_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s2_args is None else dense_s2_args,
+dense_s3_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s3_args is None else dense_s3_args,
 pooling_fun=pooling_fun,
 )
 
@@ -369,7 +369,7 @@ def __init__(
 self.out_layer = Dense(summary_dim, activation="linear")
 self.summary_dim = summary_dim
 
-def call(self, x):
+def call(self, x, **kwargs):
 """Performs the forward pass of a learnable deep invariant transformation consisting of
 a sequence of equivariant transforms followed by an invariant transform.
 
@@ -385,10 +385,10 @@ def call(self, x):
 """
 
 # Pass through series of augmented equivariant transforms
-out_equiv = self.equiv_layers(x)
+out_equiv = self.equiv_layers(x, **kwargs)
 
 # Pass through final invariant layer
-out = self.out_layer(self.inv(out_equiv))
+out = self.out_layer(self.inv(out_equiv, **kwargs), **kwargs)
 
 return out
 
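Together with the signature change above, forwarding `**kwargs` means that call-time flags such as Keras' `training` argument can now reach the equivariant and invariant sub-layers. A hedged usage sketch (the `summary_dim` argument and the toy shapes are assumptions for illustration):

```python
import numpy as np
import bayesflow as bf

summary_net = bf.networks.DeepSet(summary_dim=10)

# Toy batch: 8 data sets, each a set of 50 exchangeable observations of dim 2
fake_sets = np.random.normal(size=(8, 50, 2)).astype(np.float32)

# The training flag is now forwarded through call(), so any stochastic
# components in the sub-layers (e.g., dropout, if configured) behave correctly.
train_summaries = summary_net(fake_sets, training=True)
eval_summaries = summary_net(fake_sets, training=False)
```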

@@ -443,7 +443,7 @@ def __init__(
 conv_settings : dict or None, optional, default: None
 The arguments passed to the `MultiConv1D` internal networks. If `None`,
 defaults will be used from `default_settings`. If a dictionary is provided,
-it should contain the followin keys:
+it should contain the following keys:
 - layer_args (dict) : arguments for `tf.keras.layers.Conv1D` without kernel_size
 - min_kernel_size (int) : the minimum kernel size (>= 1)
 - max_kernel_size (int) : the maximum kernel size
@@ -508,8 +508,8 @@ class SplitNetwork(tf.keras.Model):
 of data to provide an individual network for each split of the data.
 """
 
-def __init__(self, num_splits, split_data_configurator, network_type=InvariantNetwork, network_kwargs={}, **kwargs):
-"""Creates a composite network of `num_splits` sub-networks of type `network_type`, each with configuration
+def __init__(self, num_splits, split_data_configurator, network_type=DeepSet, network_kwargs={}, **kwargs):
+"""Creates a composite network of `num_splits` subnetworks of type `network_type`, each with configuration
 specified by `meta`.
 
 Parameters
@@ -535,7 +535,7 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe
 indicating which rows belong to the split `i`.
 network_type : callable, optional, default: `InvariantNetowk`
 Type of neural network to use.
-meta : dict, optional, default: {}
+network_kwargs : dict, optional, default: {}
 A dictionary containing the configuration for the networks.
 **kwargs
 Optional keyword arguments to be passed to the `tf.keras.Model` superclass.
@@ -547,7 +547,7 @@ def __init__(self, num_splits, split_data_configurator, network_type=InvariantNe
 self.split_data_configurator = split_data_configurator
 self.networks = [network_type(**network_kwargs) for _ in range(num_splits)]
 
-def call(self, x):
+def call(self, x, **kwargs):
 """Performs a forward pass through the subnetworks and concatenates their output.
 
 Parameters
@@ -561,7 +561,7 @@ def call(self, x):
 Output of shape (batch_size, out_dim)
 """
 
-out = [self.networks[i](self.split_data_configurator(i, x)) for i in range(self.num_splits)]
+out = [self.networks[i](self.split_data_configurator(i, x), **kwargs) for i in range(self.num_splits)]
 out = tf.concat(out, axis=-1)
 return out
 
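Taken together, the `SplitNetwork` changes switch the default sub-network to `DeepSet`, rename the configuration argument to `network_kwargs`, and forward `**kwargs` to each sub-network. A hypothetical usage sketch; the dimension-wise `split_data_configurator` and the toy shapes are purely illustrative:

```python
import numpy as np
from bayesflow.summary_networks import SplitNetwork

def split_by_dimension(i, x):
    # Illustrative split: sub-network i only sees data dimension i,
    # kept 3D so each DeepSet still receives (batch, set size, dim) input.
    return x[..., i:i + 1]

split_net = SplitNetwork(
    num_splits=2,
    split_data_configurator=split_by_dimension,
    network_kwargs=dict(summary_dim=4),  # passed to each default DeepSet
)

fake_data = np.random.normal(size=(8, 50, 2)).astype(np.float32)
summaries = split_net(fake_data, training=False)  # **kwargs now reach the sub-networks
# summaries has shape (8, 2 * 4) after concatenation along the last axis
```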

@@ -602,7 +602,7 @@ def call(self, x, return_all=False, **kwargs):
 
 Parameters
 ----------
-data : tf.Tensor of shape (batch_size, ..., data_dim)
+x : tf.Tensor of shape (batch_size, ..., data_dim)
 Example, hierarchical data sets with two levels:
 (batch_size, D, L, x_dim) -> reduces to (batch_size, out_dim).
 return_all : boolean, optional, default: False
