Skip to content

Commit 44415f4

Browse files
committed
Update docs and typehints
1 parent b82716b commit 44415f4

File tree

3 files changed

+52
-4
lines changed

3 files changed

+52
-4
lines changed

bayesflow/networks/inference_network.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
import keras
1+
from typing import Literal
22
from collections.abc import Sequence
33

4+
import keras
5+
46
from bayesflow.types import Shape, Tensor
57
from bayesflow.utils import layer_kwargs, find_distribution
68
from bayesflow.utils.decorators import allow_batch_size
@@ -9,7 +11,28 @@
911

1012
@serializable("bayesflow.networks")
1113
class InferenceNetwork(keras.Layer):
12-
def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] | None = None, **kwargs):
14+
def __init__(
15+
self,
16+
base_distribution: Literal["normal", "student", "mixture"] | keras.Layer = "normal",
17+
*,
18+
metrics: Sequence[keras.Metric] | None = None,
19+
**kwargs,
20+
):
21+
"""
22+
Constructs an inference network using a specified base distribution and optional custom metrics.
23+
Use this interface for custom inference networks.
24+
25+
Parameters
26+
----------
27+
base_distribution : Literal["normal", "student", "mixture"] or keras.Layer
28+
Name or the actual base distribution to use. Passed to `find_distribution` to
29+
obtain the corresponding distribution object.
30+
metrics : Sequence[keras.Metric] or None, optional
31+
Sequence of custom Keras Metric instances to compute during training
32+
and evaluation. If `None`, no custom metrics are used.
33+
**kwargs
34+
Additional keyword arguments forwarded to the `keras.Layer` constructor.
35+
"""
1336
self.custom_metrics = metrics
1437
super().__init__(**layer_kwargs(kwargs))
1538
self.base_distribution = find_distribution(base_distribution)
@@ -70,7 +93,7 @@ def compute_metrics(
7093

7194
if stage != "training" and any(self.metrics):
7295
# compute sample-based metrics
73-
samples = self.sample((keras.ops.shape(x)[0],), conditions=conditions)
96+
samples = self.sample(batch_shape=(keras.ops.shape(x)[0],), conditions=conditions)
7497

7598
for metric in self.metrics:
7699
metrics[metric.name] = metric(samples, x)

bayesflow/networks/mlp/mlp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,11 @@ def __init__(
5555
dropout : float or None, optional
5656
Dropout rate applied within the MLP layers for regularization. Default is 0.05.
5757
norm: str, optional
58-
58+
Type of learnable normalization to be used (e.g., "batch" or "layer"). Default is None.
5959
spectral_normalization : bool, optional
6060
Whether to apply spectral normalization to stabilize training. Default is False.
61+
metrics: Sequence[keras.Metric], optional
62+
A sequence of callable metrics following keras' `Metric` signature. Default is None.
6163
**kwargs
6264
Additional keyword arguments passed to the Keras layer initialization.
6365
"""

bayesflow/networks/summary_network.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,29 @@
1111
@serializable("bayesflow.networks")
1212
class SummaryNetwork(keras.Layer):
1313
def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] | None = None, **kwargs):
14+
"""
15+
Builds a summary network with an optional base distribution and custom metrics. Use this class
16+
as an interface for custom summary networks.
17+
18+
Important: If a base distribution is passed, the summary outputs will be optimized to follow
19+
said distribution, as described in [1].
20+
21+
[1] Schmitt, M., Bürkner, P. C., Köthe, U., & Radev, S. T. (2023).
22+
Detecting model misspecification in amortized Bayesian inference with neural networks.
23+
In DAGM German Conference on Pattern Recognition (pp. 541-557). Cham: Springer Nature Switzerland.
24+
25+
Parameters
26+
----------
27+
base_distribution : str or None, default None
28+
Name of the base distribution to use. If `None`, a default distribution
29+
is chosen. Passed to `find_distribution` to obtain the corresponding
30+
distribution object.
31+
metrics : Sequence[keras.Metric] or None, optional
32+
Sequence of custom Keras Metric instances to compute during training
33+
and evaluation. If `None`, no custom metrics are used.
34+
**kwargs
35+
Additional keyword arguments forwarded to the `keras.Layer` constructor.
36+
"""
1437
self.custom_metrics = metrics
1538
super().__init__(**layer_kwargs(kwargs))
1639
self.base_distribution = find_distribution(base_distribution)

0 commit comments

Comments
 (0)