Skip to content

Commit f6a70b5

Browse files
committed
Improve workflow tests with multiple summary nets / approximators
1 parent 25f5c64 commit f6a70b5

File tree

2 files changed

+63
-13
lines changed

2 files changed

+63
-13
lines changed

tests/test_workflows/conftest.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,52 @@
11
import pytest
22

3+
import keras
34

4-
@pytest.fixture()
5-
def inference_network():
6-
from bayesflow.networks import CouplingFlow
5+
from bayesflow.utils.serialization import serializable
76

8-
return CouplingFlow(depth=2)
97

8+
@pytest.fixture(params=["coupling_flow", "flow_matching"])
9+
def inference_network(request):
10+
if request.param == "coupling_flow":
11+
from bayesflow.networks import CouplingFlow
1012

11-
@pytest.fixture()
12-
def summary_network():
13-
from bayesflow.networks import TimeSeriesTransformer
13+
return CouplingFlow(depth=2)
1414

15-
return TimeSeriesTransformer(embed_dims=(8, 8), mlp_widths=(32, 32), mlp_depths=(1, 1))
15+
elif request.param == "flow_matching":
16+
from bayesflow.networks import FlowMatching
17+
18+
return FlowMatching(subnet_kwargs=dict(widths=(32, 32)), use_optimal_transport=False)
19+
20+
21+
@pytest.fixture(params=["time_series_transformer", "fusion_transformer", "time_series_network", "custom"])
22+
def summary_network(request):
23+
if request.param == "time_series_transformer":
24+
from bayesflow.networks import TimeSeriesTransformer
25+
26+
return TimeSeriesTransformer(embed_dims=(8, 8), mlp_widths=(16, 8), mlp_depths=(1, 1))
27+
28+
elif request.param == "fusion_transformer":
29+
from bayesflow.networks import FusionTransformer
30+
31+
return FusionTransformer(
32+
embed_dims=(8, 8), mlp_widths=(8, 16), mlp_depths=(2, 1), template_dim=8, bidirectional=False
33+
)
34+
35+
elif request.param == "time_series_network":
36+
from bayesflow.networks import TimeSeriesNetwork
37+
38+
return TimeSeriesNetwork(filters=4, skip_steps=2)
39+
40+
elif request.param == "custom":
41+
from bayesflow.networks import SummaryNetwork
42+
43+
@serializable
44+
class Custom(SummaryNetwork):
45+
def __init__(self, **kwargs):
46+
super().__init__(**kwargs)
47+
self.inner = keras.Sequential([keras.layers.LSTM(8), keras.layers.Dense(4)])
48+
49+
def call(self, x, **kwargs):
50+
return self.inner(x, training=kwargs.get("stage") == "training")
51+
52+
return Custom()
Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,34 @@
1+
import os
2+
3+
import keras
4+
15
import bayesflow as bf
26

37

4-
def test_basic_workflow(inference_network, summary_network):
8+
def test_basic_workflow(tmp_path, inference_network, summary_network):
59
workflow = bf.BasicWorkflow(
610
inference_network=inference_network,
711
summary_network=summary_network,
812
inference_variables=["parameters"],
913
summary_variables=["observables"],
1014
simulator=bf.simulators.SIR(),
15+
checkpoint_filepath=str(tmp_path),
1116
)
1217

13-
history = workflow.fit_online(epochs=2, batch_size=32, num_batches_per_epoch=2)
14-
plots = workflow.plot_default_diagnostics(test_data=50, num_samples=50)
15-
metrics = workflow.compute_default_diagnostics(test_data=50, num_samples=50, variable_names=["p1", "p2"])
18+
# Ensure metrics work fine
19+
history = workflow.fit_online(epochs=4, batch_size=8, num_batches_per_epoch=2, verbose=0)
20+
plots = workflow.plot_default_diagnostics(test_data=50, num_samples=25)
21+
metrics = workflow.compute_default_diagnostics(test_data=50, num_samples=25, variable_names=["p1", "p2"])
1622

1723
assert "loss" in list(history.history.keys())
18-
assert len(history.history["loss"]) == 2
24+
assert len(history.history["loss"]) == 4
1925
assert list(plots.keys()) == ["losses", "recovery", "calibration_ecdf", "z_score_contraction"]
2026
assert list(metrics.columns) == ["p1", "p2"]
2127
assert metrics.values.shape == (3, 2)
28+
29+
# Ensure saving and loading from workflow works fine
30+
loaded_approximator = keras.saving.load_model(os.path.join(str(tmp_path), "model.keras"))
31+
32+
# Get samples
33+
samples = loaded_approximator.sample(conditions=workflow.simulate(5), num_samples=3)
34+
assert samples["parameters"].shape == (5, 3, 2)

0 commit comments

Comments
 (0)