Skip to content

Commit e60a25e

Browse files
committed
Tests for ApproximatorEnsemble
1 parent 4d3130b commit e60a25e

File tree

7 files changed

+287
-0
lines changed

7 files changed

+287
-0
lines changed

tests/test_approximators/test_approximator_ensemble/__init__.py

Whitespace-only changes.
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import pytest
2+
import numpy as np
3+
from tests.utils import check_combination_simulator_adapter
4+
5+
6+
@pytest.fixture()
7+
def train_dataset_for_ensemble(batch_size, adapter, simulator):
8+
check_combination_simulator_adapter(simulator, adapter)
9+
10+
from bayesflow import OfflineEnsembleDataset
11+
12+
num_batches = 4
13+
data = simulator.sample((num_batches * batch_size,))
14+
return OfflineEnsembleDataset(
15+
num_ensemble=2, data=data, adapter=adapter, batch_size=batch_size, workers=4, max_queue_size=num_batches
16+
)
17+
18+
19+
@pytest.fixture()
20+
def continuous_and_point_approximator_ensemble(
21+
continuous_approximator, point_approximator_with_single_parametric_score
22+
):
23+
from bayesflow import ApproximatorEnsemble
24+
25+
return ApproximatorEnsemble(
26+
dict(cont_approx=continuous_approximator, point_approx=point_approximator_with_single_parametric_score)
27+
)
28+
29+
30+
@pytest.fixture(
31+
params=[
32+
"continuous_and_point_approximator_ensemble",
33+
],
34+
scope="function",
35+
)
36+
def continuous_approximator_ensemble(request):
37+
return request.getfixturevalue(request.param)
38+
39+
40+
@pytest.fixture
41+
def model_comparison_simulator():
42+
from bayesflow import make_simulator
43+
from bayesflow.simulators import ModelComparisonSimulator
44+
45+
def context(batch_shape, n=None):
46+
if n is None:
47+
n = np.random.randint(2, 5)
48+
return dict(n=n)
49+
50+
def prior_null():
51+
return dict(mu=0.0)
52+
53+
def prior_alternative():
54+
mu = np.random.normal(loc=0, scale=1)
55+
return dict(mu=mu)
56+
57+
def likelihood(n, mu):
58+
x = np.random.normal(loc=mu, scale=1, size=n)
59+
return dict(x=x)
60+
61+
simulator_null = make_simulator([prior_null, likelihood])
62+
simulator_alternative = make_simulator([prior_alternative, likelihood])
63+
return ModelComparisonSimulator(
64+
simulators=[simulator_null, simulator_alternative],
65+
use_mixed_batches=True,
66+
shared_simulator=context,
67+
)
68+
69+
70+
@pytest.fixture()
71+
def model_comparison_train_dataset_for_ensemble(batch_size, adapter, simulator):
72+
check_combination_simulator_adapter(simulator, adapter)
73+
74+
from bayesflow import OfflineEnsembleDataset
75+
76+
num_batches = 4
77+
data = simulator.sample((num_batches * batch_size,))
78+
return OfflineEnsembleDataset(
79+
num_ensemble=2, data=data, adapter=adapter, batch_size=batch_size, workers=4, max_queue_size=num_batches
80+
)
81+
82+
83+
@pytest.fixture
84+
def model_comparison_adapter():
85+
from bayesflow import Adapter
86+
87+
return (
88+
Adapter()
89+
.sqrt("n")
90+
.broadcast("n", to="x")
91+
.as_set("x")
92+
.rename("n", "classifier_conditions")
93+
.rename("x", "summary_variables")
94+
.drop("mu")
95+
.convert_dtype("float64", "float32")
96+
)
97+
98+
99+
@pytest.fixture()
100+
def basic_model_comparison_ensemble(model_comparison_adapter):
101+
from bayesflow.approximators import ModelComparisonApproximator, ApproximatorEnsemble
102+
from bayesflow.networks import DeepSet, MLP
103+
104+
classifier_network = MLP(widths=[32, 32])
105+
106+
summary_network = DeepSet(summary_dim=2, depth=1)
107+
108+
approx_1 = ModelComparisonApproximator(
109+
num_models=2,
110+
classifier_network=classifier_network,
111+
adapter=model_comparison_adapter,
112+
summary_network=summary_network,
113+
)
114+
approx_2 = ModelComparisonApproximator(
115+
num_models=2,
116+
classifier_network=classifier_network,
117+
adapter=model_comparison_adapter,
118+
summary_network=summary_network,
119+
)
120+
121+
return ApproximatorEnsemble(dict(approx_1=approx_1, approx_2=approx_2))
122+
123+
124+
@pytest.fixture(
125+
params=[
126+
"basic_model_comparison_ensemble",
127+
],
128+
scope="function",
129+
)
130+
def model_comparison_approximator_ensemble(request):
131+
return request.getfixturevalue(request.param)
132+
133+
134+
@pytest.fixture(
135+
params=[
136+
"continuous_and_point_approximator_ensemble",
137+
"basic_model_comparison_ensemble",
138+
],
139+
scope="function",
140+
)
141+
def approximator_ensemble(request):
142+
return request.getfixturevalue(request.param)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import keras
2+
from tests.utils import check_combination_simulator_adapter
3+
4+
5+
def test_build_continuous(continuous_approximator_ensemble, simulator, batch_size, adapter):
6+
check_combination_simulator_adapter(simulator, adapter)
7+
8+
num_batches = 4
9+
data = simulator.sample((num_batches * batch_size,))
10+
11+
batch = adapter(data)
12+
batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
13+
batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
14+
print(batch_shapes)
15+
continuous_approximator_ensemble.build(batch_shapes)
16+
17+
for member in continuous_approximator_ensemble.approximators.values():
18+
for layer in member.standardize_layers.values():
19+
assert layer.built
20+
for count in layer.count:
21+
assert count == 0.0
22+
23+
24+
def test_build_model_comparison(
25+
model_comparison_approximator_ensemble, model_comparison_simulator, batch_size, model_comparison_adapter
26+
):
27+
check_combination_simulator_adapter(model_comparison_simulator, model_comparison_adapter)
28+
29+
num_batches = 4
30+
data = model_comparison_simulator.sample((num_batches * batch_size,))
31+
32+
batch = model_comparison_adapter(data)
33+
batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
34+
batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
35+
model_comparison_approximator_ensemble.build(batch_shapes)
36+
37+
for member in model_comparison_approximator_ensemble.approximators.values():
38+
for layer in member.standardize_layers.values():
39+
assert layer.built
40+
for count in layer.count:
41+
assert count == 0.0
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import keras
2+
from tests.utils import check_combination_simulator_adapter
3+
4+
5+
def test_approximator_estimate(continuous_approximator_ensemble, simulator, batch_size, adapter):
6+
check_combination_simulator_adapter(simulator, adapter)
7+
8+
num_batches = 4
9+
data = simulator.sample((num_batches * batch_size,))
10+
11+
batch = adapter(data)
12+
batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
13+
batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
14+
continuous_approximator_ensemble.build(batch_shapes)
15+
16+
estimates = continuous_approximator_ensemble.estimate(data)
17+
18+
assert isinstance(estimates, dict)
19+
20+
for estimates_value in estimates.values():
21+
assert isinstance(estimates_value, dict)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import keras
2+
import numpy as np
3+
from tests.utils import check_combination_simulator_adapter
4+
5+
6+
def test_approximator_log_prob(continuous_approximator_ensemble, simulator, batch_size, adapter):
7+
check_combination_simulator_adapter(simulator, adapter)
8+
9+
num_batches = 4
10+
data = simulator.sample((num_batches * batch_size,))
11+
12+
batch = adapter(data)
13+
batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
14+
batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
15+
continuous_approximator_ensemble.build(batch_shapes)
16+
17+
log_prob = continuous_approximator_ensemble.log_prob(data=data)
18+
assert isinstance(log_prob, dict)
19+
20+
for log_prob_value in log_prob.values():
21+
assert isinstance(log_prob_value, (np.ndarray, dict))
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import keras
2+
import numpy as np
3+
from tests.utils import check_combination_simulator_adapter
4+
5+
6+
# def test_predict(model_comparison_approximator_ensemble, model_comparison_train_dataset_for_ensemble, simulator):
7+
# approximator_ensemble = model_comparison_approximator_ensemble
8+
# data_shapes = keras.tree.map_structure(keras.ops.shape, model_comparison_train_dataset_for_ensemble[0])
9+
# approximator_ensemble.build(data_shapes)
10+
# approximator_ensemble.compute_metrics(**model_comparison_train_dataset_for_ensemble[0])
11+
#
12+
# num_conditions = 2
13+
# conditions = simulator.sample(num_conditions)
14+
# predictions = approximator_ensemble.predict(conditions=conditions)
15+
#
16+
# for predictions_value in predictions.values():
17+
# assert isinstance(predictions_value, np.ndarray)
18+
# assert predictions_value.shape[0] == num_conditions
19+
20+
21+
def test_predict_model_comparison(
22+
model_comparison_approximator_ensemble, model_comparison_simulator, batch_size, model_comparison_adapter
23+
):
24+
check_combination_simulator_adapter(model_comparison_simulator, model_comparison_adapter)
25+
26+
num_batches = 4
27+
data = model_comparison_simulator.sample((num_batches * batch_size,))
28+
29+
batch = model_comparison_adapter(data)
30+
batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
31+
batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
32+
model_comparison_approximator_ensemble.build(batch_shapes)
33+
34+
num_conditions = 2
35+
conditions = model_comparison_simulator.sample(num_conditions)
36+
print(conditions)
37+
predictions = model_comparison_approximator_ensemble.predict(conditions=conditions)
38+
39+
for predictions_value in predictions.values():
40+
assert isinstance(predictions_value, np.ndarray)
41+
assert predictions_value.shape[0] == num_conditions
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import keras
2+
from tests.utils import check_combination_simulator_adapter
3+
4+
5+
def test_approximator_sample(continuous_approximator_ensemble, simulator, batch_size, adapter):
6+
check_combination_simulator_adapter(simulator, adapter)
7+
8+
num_batches = 4
9+
data = simulator.sample((num_batches * batch_size,))
10+
11+
batch = adapter(data)
12+
batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
13+
batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
14+
continuous_approximator_ensemble.build(batch_shapes)
15+
16+
samples = continuous_approximator_ensemble.sample(num_samples=2, conditions=data)
17+
18+
assert isinstance(samples, dict)
19+
20+
for samples_value in samples.values():
21+
assert isinstance(samples_value, dict)

0 commit comments

Comments
 (0)