Skip to content

Commit af3f19c

Browse files
committed
adjust signatures, extend tests
1 parent 1eda62c commit af3f19c

File tree

7 files changed

+24
-12
lines changed

7 files changed

+24
-12
lines changed

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,16 @@ def build_dataset(
106106
def compile(
107107
self,
108108
*args,
109-
classifier_metrics: Sequence[keras.Metric] = None,
110-
summary_metrics: Sequence[keras.Metric] = None,
111109
**kwargs,
112110
):
113-
if classifier_metrics:
111+
if "classifier_metrics" in kwargs:
114112
warnings.warn(
115113
"Supplying classifier metrics to the approximator is no longer supported. "
116114
"Please pass the metrics directly to the network using the metrics parameter.",
117115
DeprecationWarning,
118116
)
119117

120-
if summary_metrics:
118+
if "summary_metrics" in kwargs:
121119
warnings.warn(
122120
"Supplying summary metrics to the approximator is no longer supported. "
123121
"Please pass the metrics directly to the network using the metrics parameter.",
@@ -166,8 +164,10 @@ def compute_metrics(
166164
classifier_metrics |= {
167165
metric.name: metric(model_indices, predictions) for metric in self.classifier_network.metrics
168166
}
169-
170-
loss = classifier_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))
167+
if "loss" in summary_metrics:
168+
loss = classifier_metrics["loss"] + summary_metrics["loss"]
169+
else:
170+
loss = classifier_metrics.pop("loss")
171171

172172
classifier_metrics = {f"{key}/classifier_{key}": value for key, value in classifier_metrics.items()}
173173
summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()}

bayesflow/networks/inference_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
@serializable("bayesflow.networks")
1111
class InferenceNetwork(keras.Layer):
12-
def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] = None, **kwargs):
12+
def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] | None = None, **kwargs):
1313
self.custom_metrics = metrics
1414
super().__init__(**layer_kwargs(kwargs))
1515
self.base_distribution = find_distribution(base_distribution)

bayesflow/networks/summary_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
@serializable("bayesflow.networks")
1212
class SummaryNetwork(keras.Layer):
13-
def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] = None, **kwargs):
13+
def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] | None = None, **kwargs):
1414
self.custom_metrics = metrics
1515
super().__init__(**layer_kwargs(kwargs))
1616
self.base_distribution = find_distribution(base_distribution)

tests/test_approximators/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@ def summary_network():
2020
@pytest.fixture()
2121
def inference_network():
2222
from bayesflow.networks import CouplingFlow
23+
from bayesflow.metrics import RootMeanSquaredError
2324

24-
return CouplingFlow(subnet="mlp", depth=2, subnet_kwargs=dict(widths=(32, 32)))
25+
return CouplingFlow(
26+
subnet="mlp", depth=2, subnet_kwargs=dict(widths=(32, 32)), metrics=[RootMeanSquaredError(name="rmse")]
27+
)
2528

2629

2730
@pytest.fixture()

tests/test_approximators/test_fit.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,7 @@ def test_loss_progress(approximator, train_dataset, validation_dataset):
4949
# check that the shown loss is not nan or zero
5050
assert re.search(r"\bnan\b", output) is None, "found nan in output"
5151
assert re.search(r"\bloss: 0\.0000e\+00\b", output) is None, "found zero loss in output"
52+
53+
# check that additional metric is present
54+
assert "val_rmse/inference_rmse" in output, "custom metric (RMSE) not shown"
55+
assert re.search(r"\bval_rmse/inference_rmse: \d+\.\d+", output) is not None, "custom metric not correctly shown"

tests/test_approximators/test_model_comparison_approximator/conftest.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,17 @@ def adapter():
5151
@pytest.fixture
5252
def summary_network():
5353
from bayesflow.networks import DeepSet
54+
from bayesflow.metrics import RootMeanSquaredError
5455

55-
return DeepSet(summary_dim=2, depth=1)
56+
return DeepSet(summary_dim=2, depth=1, base_distribution="normal", metrics=[RootMeanSquaredError(name="rmse")])
5657

5758

5859
@pytest.fixture
5960
def classifier_network():
6061
from bayesflow.networks import MLP
62+
from keras.metrics import CategoricalAccuracy
6163

62-
return MLP(widths=[32, 32])
64+
return MLP(widths=[32, 32], metrics=[CategoricalAccuracy(name="categorical_accuracy")])
6365

6466

6567
@pytest.fixture

tests/test_approximators/test_model_comparison_approximator/test_model_comparison_approximator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ def test_fit(approximator, train_dataset, validation_dataset):
5555

5656
output = stream.getvalue()
5757
# check that the loss is shown
58-
assert "loss" in output
58+
assert "loss/summary_loss" in output
59+
assert "loss/classifier_loss" in output
60+
assert "val_categorical_accuracy/classifier_categorical_accuracy" in output
61+
assert "val_rmse/summary_rmse" in output
5962

6063

6164
def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset):

0 commit comments

Comments
 (0)