|
1 | | -import keras |
| 1 | +from typing import Literal |
2 | 2 | from collections.abc import Sequence |
3 | 3 |
|
| 4 | +import keras |
| 5 | + |
4 | 6 | from bayesflow.types import Shape, Tensor |
5 | 7 | from bayesflow.utils import layer_kwargs, find_distribution |
6 | 8 | from bayesflow.utils.decorators import allow_batch_size |
|
9 | 11 |
|
10 | 12 | @serializable("bayesflow.networks") |
11 | 13 | class InferenceNetwork(keras.Layer): |
12 | | - def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] | None = None, **kwargs): |
| 14 | + def __init__( |
| 15 | + self, |
| 16 | + base_distribution: Literal["normal", "student", "mixture"] | keras.Layer = "normal", |
| 17 | + *, |
| 18 | + metrics: Sequence[keras.Metric] | None = None, |
| 19 | + **kwargs, |
| 20 | + ): |
| 21 | + """ |
| 22 | + Constructs an inference network using a specified base distribution and optional custom metrics. |
| 23 | + Use this interface for custom inference networks. |
| 24 | +
|
| 25 | + Parameters |
| 26 | + ---------- |
| 27 | + base_distribution : Literal["normal", "student", "mixture"] or keras.Layer |
| 28 | + Name or the actual base distribution to use. Passed to `find_distribution` to |
| 29 | + obtain the corresponding distribution object. |
| 30 | + metrics : Sequence[keras.Metric] or None, optional |
| 31 | + Sequence of custom Keras Metric instances to compute during training |
| 32 | + and evaluation. If `None`, no custom metrics are used. |
| 33 | + **kwargs |
| 34 | + Additional keyword arguments forwarded to the `keras.Layer` constructor. |
| 35 | + """ |
13 | 36 | self.custom_metrics = metrics |
14 | 37 | super().__init__(**layer_kwargs(kwargs)) |
15 | 38 | self.base_distribution = find_distribution(base_distribution) |
@@ -70,7 +93,7 @@ def compute_metrics( |
70 | 93 |
|
71 | 94 | if stage != "training" and any(self.metrics): |
72 | 95 | # compute sample-based metrics |
73 | | - samples = self.sample((keras.ops.shape(x)[0],), conditions=conditions) |
| 96 | + samples = self.sample(batch_shape=(keras.ops.shape(x)[0],), conditions=conditions) |
74 | 97 |
|
75 | 98 | for metric in self.metrics: |
76 | 99 | metrics[metric.name] = metric(samples, x) |
|
0 commit comments