Skip to content

Commit ac0461a

Browse files
committed
add tests for model comparison approximator
1 parent 2c90547 commit ac0461a

File tree

3 files changed

+156
-0
lines changed

3 files changed

+156
-0
lines changed

tests/test_approximators/test_model_comparison_approximator/__init__.py

Whitespace-only changes.
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import pytest
2+
import numpy as np
3+
4+
5+
@pytest.fixture
6+
def simulator():
7+
from bayesflow import make_simulator
8+
from bayesflow.simulators import ModelComparisonSimulator
9+
10+
def context(batch_shape, n=None):
11+
if n is None:
12+
n = np.random.randint(2, 5)
13+
return dict(n=n)
14+
15+
def prior_null():
16+
return dict(mu=0.0)
17+
18+
def prior_alternative():
19+
mu = np.random.normal(loc=0, scale=1)
20+
return dict(mu=mu)
21+
22+
def likelihood(n, mu):
23+
x = np.random.normal(loc=mu, scale=1, size=n)
24+
return dict(x=x)
25+
26+
simulator_null = make_simulator([prior_null, likelihood])
27+
simulator_alternative = make_simulator([prior_alternative, likelihood])
28+
return ModelComparisonSimulator(
29+
simulators=[simulator_null, simulator_alternative],
30+
use_mixed_batches=True,
31+
shared_simulator=context,
32+
)
33+
34+
35+
@pytest.fixture
36+
def adapter():
37+
from bayesflow import Adapter
38+
39+
return (
40+
Adapter()
41+
.sqrt("n")
42+
.broadcast("n", to="x")
43+
.as_set("x")
44+
.rename("n", "classifier_conditions")
45+
.rename("x", "summary_variables")
46+
.drop("mu")
47+
.convert_dtype("float64", "float32")
48+
)
49+
50+
51+
@pytest.fixture
52+
def summary_network():
53+
from bayesflow.networks import DeepSet
54+
55+
return DeepSet(summary_dim=2, depth=1)
56+
57+
58+
@pytest.fixture
59+
def classifier_network():
60+
from bayesflow.networks import MLP
61+
62+
return MLP(widths=[32, 32])
63+
64+
65+
@pytest.fixture
66+
def approximator(adapter, classifier_network, summary_network, simulator):
67+
from bayesflow.approximators import ModelComparisonApproximator
68+
69+
return ModelComparisonApproximator(
70+
num_models=len(simulator.simulators),
71+
classifier_network=classifier_network,
72+
adapter=adapter,
73+
summary_network=summary_network,
74+
)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import keras
2+
import numpy as np
3+
import io
4+
from contextlib import redirect_stdout
5+
6+
from tests.utils import assert_models_equal
7+
8+
9+
def test_build(approximator, train_dataset):
10+
assert approximator.built is False
11+
12+
data_shapes = keras.tree.map_structure(keras.ops.shape, train_dataset[0])
13+
approximator.build(data_shapes)
14+
15+
assert approximator.built is True
16+
assert approximator.classifier_network.built is True
17+
if approximator.summary_network is not None:
18+
assert approximator.summary_network.built is True
19+
20+
21+
def test_build_adapter():
22+
from bayesflow.approximators import ModelComparisonApproximator
23+
24+
_ = ModelComparisonApproximator.build_adapter(
25+
num_models=2,
26+
classifier_conditions=["foo", "bar"],
27+
summary_variables=["observables"],
28+
model_index_name=["indices"],
29+
)
30+
31+
32+
def test_build_dataset(simulator, adapter):
33+
from bayesflow.approximators import ModelComparisonApproximator
34+
from bayesflow.datasets import OnlineDataset
35+
36+
dataset = ModelComparisonApproximator.build_dataset(
37+
simulator=simulator,
38+
memory_budget="20 KiB",
39+
num_batches=2,
40+
num_models=2,
41+
classifier_conditions="foo",
42+
summary_variables=["x1", "x2"],
43+
)
44+
assert isinstance(dataset, OnlineDataset)
45+
46+
47+
def test_fit(approximator, train_dataset, validation_dataset):
48+
approximator.compile(optimizer="AdamW")
49+
num_epochs = 1
50+
51+
# Capture ostream and train model
52+
with io.StringIO() as stream:
53+
with redirect_stdout(stream):
54+
approximator.fit(dataset=train_dataset, validation_data=validation_dataset, epochs=num_epochs)
55+
56+
output = stream.getvalue()
57+
# check that the loss is shown
58+
assert "loss" in output
59+
60+
61+
def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset):
62+
# to save, the model must be built
63+
data_shapes = keras.tree.map_structure(keras.ops.shape, train_dataset[0])
64+
approximator.build(data_shapes)
65+
approximator.compute_metrics(**train_dataset[0])
66+
67+
keras.saving.save_model(approximator, tmp_path / "model.keras")
68+
loaded = keras.saving.load_model(tmp_path / "model.keras")
69+
70+
assert_models_equal(approximator, loaded)
71+
72+
73+
def test_predict(approximator, train_dataset, simulator):
74+
data_shapes = keras.tree.map_structure(keras.ops.shape, train_dataset[0])
75+
approximator.build(data_shapes)
76+
approximator.compute_metrics(**train_dataset[0])
77+
78+
num_conditions = 2
79+
conditions = simulator.sample(num_conditions)
80+
output = approximator.predict(conditions=conditions)
81+
assert isinstance(output, np.ndarray)
82+
assert output.shape[0] == num_conditions

0 commit comments

Comments
 (0)