Commit fea5f89

Merge pull request #202 from rusty-electron/test-summary-networks

add tests for summary networks [WIP]

2 parents: a822d01 + 17ece6d

File tree: 7 files changed, +56 / -38 lines

bayesflow/networks/transformers/fusion_transformer.py
Lines changed: 1 addition & 0 deletions

@@ -120,6 +120,7 @@ def __init__(
             raise ValueError("Argument `template_dim` should be in ['lstm', 'gru']")

         self.output_projector = keras.layers.Dense(summary_dim)
+        self.summary_dim = summary_dim

     def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
         """Compresses the input sequence into a summary vector of size `summary_dim`.

bayesflow/networks/transformers/set_transformer.py
Lines changed: 1 addition & 0 deletions

@@ -125,6 +125,7 @@ def __init__(
         )
         self.pooling_by_attention = PoolingByMultiHeadAttention(**(global_attention_settings | pooling_settings))
         self.output_projector = keras.layers.Dense(summary_dim)
+        self.summary_dim = summary_dim

     def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         """Compresses the input sequence into a summary vector of size `summary_dim`.

bayesflow/networks/transformers/time_series_transformer.py
Lines changed: 1 addition & 0 deletions

@@ -103,6 +103,7 @@ def __init__(
         # Pooling will be applied as a final step to the abstract representations obtained from set attention
         self.pooling = keras.layers.GlobalAvgPool1D()
         self.output_projector = keras.layers.Dense(summary_dim)
+        self.summary_dim = summary_dim

         self.time_axis = time_axis
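All three transformer diffs above make the same one-line change: the constructor now stores `summary_dim` as a public attribute alongside the `Dense` output projector, so tests and downstream code can read the output width directly instead of recovering it from the projector layer. A minimal, illustrative sketch of the resulting attribute (not part of the diff; the value 4 is arbitrary):

    from bayesflow.networks import SetTransformer

    # summary_dim is now recorded on the instance at construction time
    net = SetTransformer(summary_dim=4)
    assert net.summary_dim == 4
    # it mirrors the width of the final projection layer created in __init__
    assert net.output_projector.units == 4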

tests/conftest.py
Lines changed: 3 additions & 22 deletions

@@ -32,11 +32,9 @@ def conditions_size(request):
     return request.param


-@pytest.fixture(scope="function")
-def coupling_flow():
-    from bayesflow.networks import CouplingFlow
-
-    return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32)))
+@pytest.fixture(params=[1, 4], scope="session")
+def summary_dim(request):
+    return request.param


 @pytest.fixture(params=["two_moons"], scope="session")
@@ -49,16 +47,6 @@ def feature_size(request):
     return request.param


-@pytest.fixture(params=["coupling_flow"], scope="function")
-def inference_network(request):
-    return request.getfixturevalue(request.param)
-
-
-@pytest.fixture(params=["inference_network", "summary_network"], scope="function")
-def network(request):
-    return request.getfixturevalue(request.param)
-
-
 @pytest.fixture(scope="session")
 def random_conditions(batch_size, conditions_size):
     if conditions_size is None:
@@ -94,13 +82,6 @@ def simulator(request):
     return request.getfixturevalue(request.param)


-@pytest.fixture(params=[None], scope="function")
-def summary_network(request):
-    if request.param is None:
-        return None
-    return request.getfixturevalue(request.param)
-
-
 @pytest.fixture(scope="session")
 def training_dataset(simulator, batch_size):
     from bayesflow.datasets import OfflineDataset
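The net effect of this conftest change is that the old function-scoped `coupling_flow`, `inference_network`, `network`, and `summary_network` fixtures move out of the shared conftest (the `inference_network` fixture reappears in tests/test_two_moons/conftest.py below), and a session-scoped, parametrized `summary_dim` fixture takes their place. Any fixture or test that requests `summary_dim` is collected once per value in `params=[1, 4]`. A stripped-down sketch of that fan-out, with a hypothetical `network` fixture used purely for illustration:

    import pytest

    @pytest.fixture(params=[1, 4], scope="session")
    def summary_dim(request):
        return request.param

    # hypothetical dependent fixture: re-evaluated for each summary_dim value
    @pytest.fixture()
    def network(summary_dim):
        return {"summary_dim": summary_dim}  # stand-in for a real summary network

    # collected twice: test_width[1] and test_width[4]
    def test_width(network, summary_dim):
        assert network["summary_dim"] == summary_dim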

tests/test_networks/conftest.py
Lines changed: 15 additions & 16 deletions

@@ -1,13 +1,6 @@
 import pytest


-@pytest.fixture()
-def deep_set():
-    from bayesflow.networks import DeepSet
-
-    return DeepSet()
-
-
 # For the serialization tests, we want to test passing str and type.
 # For all other tests, this is not necessary and would double test time.
 # Therefore, below we specify two variants of each network, one without and
@@ -79,23 +72,29 @@ def inference_network_subnet(request):
     return request.getfixturevalue(request.param)


-@pytest.fixture()
-def lst_net():
+@pytest.fixture(scope="function")
+def lst_net(summary_dim):
     from bayesflow.networks import LSTNet

-    return LSTNet()
+    return LSTNet(summary_dim=summary_dim)


-@pytest.fixture()
-def set_transformer():
+@pytest.fixture(scope="function")
+def set_transformer(summary_dim):
     from bayesflow.networks import SetTransformer

-    return SetTransformer()
+    return SetTransformer(summary_dim=summary_dim)
+
+
+@pytest.fixture(scope="function")
+def deep_set(summary_dim):
+    from bayesflow.networks import DeepSet
+
+    return DeepSet(summary_dim=summary_dim)


-@pytest.fixture(params=[None, "deep_set", "lst_net", "set_transformer"])
-def summary_network(request):
+@pytest.fixture(params=[None, "lst_net", "set_transformer", "deep_set"], scope="function")
+def summary_network(request, summary_dim):
     if request.param is None:
         return None
-
     return request.getfixturevalue(request.param)

tests/test_networks/test_summary_networks.py
Lines changed: 28 additions & 0 deletions

@@ -79,3 +79,31 @@ def test_save_and_load(tmp_path, summary_network, random_set):
     loaded = keras.saving.load_model(tmp_path / "model.keras")

     assert_layers_equal(summary_network, loaded)
+
+
+@pytest.mark.parametrize("stage", ["training", "validation"])
+def test_compute_metrics(stage, summary_network, random_set):
+    if summary_network is None:
+        pytest.skip()
+
+    summary_network.build(keras.ops.shape(random_set))
+
+    metrics = summary_network.compute_metrics(random_set, stage=stage)
+
+    assert "outputs" in metrics
+
+    # check that the batch dimension is preserved
+    assert keras.ops.shape(metrics["outputs"])[0] == keras.ops.shape(random_set)[0]
+
+    # check summary dimension
+    summary_dim = summary_network.summary_dim
+    assert keras.ops.shape(metrics["outputs"])[-1] == summary_dim
+
+    if summary_network.base_distribution is not None:
+        assert "loss" in metrics
+        assert keras.ops.shape(metrics["loss"]) == ()
+
+    if stage != "training":
+        for metric in summary_network.metrics:
+            assert metric.name in metrics
+            assert keras.ops.shape(metrics[metric.name]) == ()
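For orientation, the new test pins down the contract of `compute_metrics` on the summary networks: it returns a dict whose "outputs" entry keeps the batch dimension and has `summary_dim` as its last axis, plus a scalar "loss" when the network has a base distribution. A hedged, self-contained illustration of that contract (the input shape and `summary_dim=4` are arbitrary choices for the example, not taken from the test fixtures):

    import keras
    from bayesflow.networks import DeepSet

    net = DeepSet(summary_dim=4)
    data = keras.random.normal((2, 3, 5))  # (batch, set size, feature dim), illustrative only
    net.build(keras.ops.shape(data))

    metrics = net.compute_metrics(data, stage="validation")
    # per the assertions above: batch dimension preserved, last axis == summary_dim
    print(keras.ops.shape(metrics["outputs"]))  # expected (2, 4)

Running just the new test locally should work with the usual pytest node-id syntax, e.g. `pytest tests/test_networks/test_summary_networks.py::test_compute_metrics` (assuming a configured Keras backend).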

tests/test_two_moons/conftest.py
Lines changed: 7 additions & 0 deletions

@@ -1,6 +1,13 @@
 import pytest


+@pytest.fixture()
+def inference_network():
+    from bayesflow.networks import CouplingFlow
+
+    return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32)))
+
+
 @pytest.fixture()
 def approximator(adapter, inference_network):
     from bayesflow import ContinuousApproximator
