Skip to content

Commit 12a6d4b

Browse files
committed
add get_config for network base classes
- get config has to be manually specified in the base classes, so that the config is stored even when to subclass overrides get_config - to preserve the auto_config behavior, we have to use the `python_utils.default` decorator from, which marks them as default methods. This allows detecting if a subclass has overridden them. This is the same mechanism that Keras uses - moved setting the `custom_metrics` parameter after the `super().__init__` calls, as the tracking is managed in setattr - extended some tests to use metrics
1 parent 5f38e86 commit 12a6d4b

File tree

8 files changed

+41
-11
lines changed

8 files changed

+41
-11
lines changed

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
141141
def get_config(self):
142142
base_config = super().get_config()
143143

144+
# base distribution is passed manually to InferenceNetwork parent class, do not store it here
145+
base_config.pop("base_distribution")
146+
144147
config = {
145148
"subnet": self.subnet,
146149
"noise_schedule": self.noise_schedule,

bayesflow/experimental/free_form_flow/free_form_flow.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
find_network,
99
jacobian,
1010
jvp,
11-
model_kwargs,
1211
vjp,
1312
weighted_mean,
1413
)
@@ -240,7 +239,6 @@ def from_config(cls, config, custom_objects=None):
240239

241240
def get_config(self):
242241
base_config = super().get_config()
243-
base_config = model_kwargs(base_config)
244242

245243
config = {
246244
"beta": self.beta,

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ def from_config(cls, config, custom_objects=None):
110110
def get_config(self):
111111
base_config = super().get_config()
112112

113+
# base distribution is passed manually to InferenceNetwork parent class, do not store it here
114+
base_config.pop("base_distribution")
115+
113116
config = {
114117
"total_steps": self.total_steps,
115118
"subnet": self.subnet,

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __init__(
9090
Keyword arguments forwarded to the affine or spline transforms
9191
(e.g., bins for splines)
9292
**kwargs
93-
Additional keyword arguments passed to `InvertibleLayer`.
93+
Additional keyword arguments passed to `InferenceNetwork`.
9494
9595
"""
9696
super().__init__(base_distribution=base_distribution, **kwargs)

bayesflow/networks/inference_network.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from collections.abc import Sequence
33

44
import keras
5+
from keras.src.utils import python_utils
56

67
from bayesflow.types import Shape, Tensor
78
from bayesflow.utils import layer_kwargs, find_distribution
89
from bayesflow.utils.decorators import allow_batch_size
9-
from bayesflow.utils.serialization import serializable
10+
from bayesflow.utils.serialization import serializable, serialize
1011

1112

1213
@serializable("bayesflow.networks")
@@ -38,8 +39,8 @@ def __init__(
3839
**kwargs
3940
Additional keyword arguments forwarded to the `keras.Layer` constructor.
4041
"""
41-
self.custom_metrics = metrics
4242
super().__init__(**layer_kwargs(kwargs))
43+
self.custom_metrics = metrics
4344
self.base_distribution = find_distribution(base_distribution)
4445

4546
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
@@ -104,3 +105,10 @@ def compute_metrics(
104105
metrics[metric.name] = metric(samples, x)
105106

106107
return metrics
108+
109+
@python_utils.default
110+
def get_config(self):
111+
base_config = super().get_config()
112+
113+
config = {"metrics": self.custom_metrics, "base_distribution": self.base_distribution}
114+
return base_config | serialize(config)

bayesflow/networks/point_inference_network.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from collections.abc import Sequence
12
import keras
3+
from keras.src.utils import python_utils
24

35
from bayesflow.utils import model_kwargs, find_network
46
from bayesflow.utils.serialization import deserialize, serializable, serialize
@@ -17,9 +19,12 @@ def __init__(
1719
self,
1820
scores: dict[str, ScoringRule],
1921
subnet: str | keras.Layer = "mlp",
22+
*,
23+
metrics: Sequence[keras.Metric] | None = None,
2024
**kwargs,
2125
):
2226
super().__init__(**model_kwargs(kwargs))
27+
self.custom_metrics = metrics
2328

2429
self.scores = scores
2530

@@ -28,6 +33,7 @@ def __init__(
2833
self.config = {
2934
"subnet": serialize(subnet),
3035
"scores": serialize(scores),
36+
"metrics": serialize(metrics),
3137
**kwargs,
3238
}
3339

@@ -106,6 +112,7 @@ def build_from_config(self, config):
106112
for head_key, head in self.heads[score_key].items():
107113
head.name = config["heads"][score_key][head_key]
108114

115+
@python_utils.default
109116
def get_config(self):
110117
base_config = super().get_config()
111118

@@ -114,9 +121,7 @@ def get_config(self):
114121
@classmethod
115122
def from_config(cls, config):
116123
config = config.copy()
117-
config["scores"] = deserialize(config["scores"])
118-
config["subnet"] = deserialize(config["subnet"])
119-
return cls(**config)
124+
return cls(**deserialize(config))
120125

121126
def call(
122127
self,

bayesflow/networks/summary_network.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import keras
2+
from keras.src.utils import python_utils
23
from collections.abc import Sequence
34

45
from bayesflow.metrics.functional import maximum_mean_discrepancy
56
from bayesflow.types import Tensor
67
from bayesflow.utils import layer_kwargs, find_distribution
78
from bayesflow.utils.decorators import sanitize_input_shape
8-
from bayesflow.utils.serialization import serializable
9+
from bayesflow.utils.serialization import serializable, serialize
910

1011

1112
@serializable("bayesflow.networks")
@@ -44,8 +45,8 @@ def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Met
4445
**kwargs
4546
Additional keyword arguments forwarded to the `keras.Layer` constructor.
4647
"""
47-
self.custom_metrics = metrics
4848
super().__init__(**layer_kwargs(kwargs))
49+
self.custom_metrics = metrics
4950
self.base_distribution = find_distribution(base_distribution)
5051

5152
@sanitize_input_shape
@@ -86,3 +87,10 @@ def compute_metrics(self, x: Tensor, stage: str = "training", **kwargs) -> dict[
8687
metrics[metric.name] = metric(outputs, samples)
8788

8889
return metrics
90+
91+
@python_utils.default
92+
def get_config(self):
93+
base_config = super().get_config()
94+
95+
config = {"metrics": self.custom_metrics, "base_distribution": self.base_distribution}
96+
return base_config | serialize(config)

tests/test_approximators/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def continuous_approximator(adapter, inference_network, summary_network):
4040

4141
@pytest.fixture()
4242
def point_inference_network():
43+
from bayesflow.metrics import RootMeanSquaredError
4344
from bayesflow.networks import PointInferenceNetwork
4445
from bayesflow.scores import NormedDifferenceScore, QuantileScore, MultivariateNormalScore
4546

@@ -51,11 +52,13 @@ def point_inference_network():
5152
),
5253
subnet="mlp",
5354
subnet_kwargs=dict(widths=(32, 32)),
55+
metrics=[RootMeanSquaredError(name="rmse")],
5456
)
5557

5658

5759
@pytest.fixture()
5860
def point_inference_network_with_multiple_parametric_scores():
61+
from bayesflow.metrics import RootMeanSquaredError
5962
from bayesflow.networks import PointInferenceNetwork
6063
from bayesflow.scores import MultivariateNormalScore
6164

@@ -64,6 +67,7 @@ def point_inference_network_with_multiple_parametric_scores():
6467
mvn1=MultivariateNormalScore(),
6568
mvn2=MultivariateNormalScore(),
6669
),
70+
metrics=[RootMeanSquaredError(name="rmse")],
6771
)
6872

6973

@@ -181,9 +185,10 @@ def validation_dataset(batch_size, adapter, simulator):
181185

182186
@pytest.fixture()
183187
def mean_std_summary_network():
188+
from bayesflow.metrics import MaximumMeanDiscrepancy
184189
from tests.utils import MeanStdSummaryNetwork
185190

186-
return MeanStdSummaryNetwork()
191+
return MeanStdSummaryNetwork(metrics=[MaximumMeanDiscrepancy("mmd")])
187192

188193

189194
@pytest.fixture(params=["continuous_approximator", "point_approximator", "model_comparison_approximator"])

0 commit comments

Comments
 (0)