Skip to content

Commit 8ed5dab

Browse files
committed
allow parametrization of kwargs
1 parent e6e89d7 commit 8ed5dab

File tree

26 files changed

+561
-591
lines changed

26 files changed

+561
-591
lines changed

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
BACKENDS = ["jax", "numpy", "tensorflow", "torch"]
77

88

9+
def pytest_addoption(parser):
10+
parser.addoption("--mode", choices=["save", "load"])
11+
parser.addoption("--commit", type=str)
12+
parser.addoption("--from", type=str, required=False, dest="from_")
13+
14+
915
def pytest_runtest_setup(item):
1016
"""Skips backends by test markers. Unmarked tests are treated as backend-agnostic"""
1117
backend = keras.backend.backend()

tests/test_compatibility/conftest.py

Lines changed: 164 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,174 @@ def data_dir(request, commit, from_commit, tmp_path_factory):
3636
return Path(tmp_path_factory.mktemp("_compatibility_data"))
3737

3838

39+
# reduce number of test configurations
40+
@pytest.fixture(params=[None, 3])
41+
def conditions_size(request):
42+
return request.param
43+
44+
45+
@pytest.fixture(params=[1, 2])
46+
def summary_dim(request):
47+
return request.param
48+
49+
50+
@pytest.fixture(params=[4])
51+
def feature_size(request):
52+
return request.param
53+
54+
55+
# Generic fixtures for use as input to the tested classes.
56+
# The classes to test are constructed in the respective subdirectories, to allow for more thorough configuation.
57+
@pytest.fixture(params=[None, "all"])
58+
def standardize(request):
59+
return request.param
60+
61+
62+
@pytest.fixture()
63+
def adapter(request):
64+
import bayesflow as bf
65+
66+
match request.param:
67+
case "summary":
68+
return bf.Adapter.create_default("parameters").rename("observables", "summary_variables")
69+
case "direct":
70+
return bf.Adapter.create_default("parameters").rename("observables", "direct_conditions")
71+
case "default":
72+
return bf.Adapter.create_default("parameters")
73+
case "empty":
74+
return bf.Adapter()
75+
case None:
76+
return None
77+
case _:
78+
raise ValueError(f"Invalid request parameter for adapter: {request.param}")
79+
80+
81+
@pytest.fixture(params=["coupling_flow", "flow_matching"])
82+
def inference_network(request):
83+
match request.param:
84+
case "coupling_flow":
85+
from bayesflow.networks import CouplingFlow
86+
87+
return CouplingFlow(depth=2)
88+
89+
case "flow_matching":
90+
from bayesflow.networks import FlowMatching
91+
92+
return FlowMatching(subnet_kwargs=dict(widths=(32, 32)), use_optimal_transport=False)
93+
94+
case None:
95+
return None
96+
97+
case _:
98+
raise ValueError(f"Invalid request parameter for inference_network: {request.param}")
99+
100+
101+
@pytest.fixture(params=["time_series_transformer", "fusion_transformer", "time_series_network", "custom"])
102+
def summary_network(request):
103+
match request.param:
104+
case "time_series_transformer":
105+
from bayesflow.networks import TimeSeriesTransformer
106+
107+
return TimeSeriesTransformer(embed_dims=(8, 8), mlp_widths=(16, 8), mlp_depths=(1, 1))
108+
109+
case "fusion_transformer":
110+
from bayesflow.networks import FusionTransformer
111+
112+
return FusionTransformer(
113+
embed_dims=(8, 8), mlp_widths=(8, 16), mlp_depths=(2, 1), template_dim=8, bidirectional=False
114+
)
115+
116+
case "time_series_network":
117+
from bayesflow.networks import TimeSeriesNetwork
118+
119+
return TimeSeriesNetwork(filters=4, skip_steps=2)
120+
121+
case "deep_set":
122+
from bayesflow.networks import DeepSet
123+
124+
return DeepSet(summary_dim=2, depth=1)
125+
126+
case "custom":
127+
from bayesflow.networks import SummaryNetwork
128+
from bayesflow.utils.serialization import serializable
129+
import keras
130+
131+
@serializable("test", disable_module_check=True)
132+
class Custom(SummaryNetwork):
133+
def __init__(self, **kwargs):
134+
super().__init__(**kwargs)
135+
self.inner = keras.Sequential([keras.layers.LSTM(8), keras.layers.Dense(4)])
136+
137+
def call(self, x, **kwargs):
138+
return self.inner(x, training=kwargs.get("stage") == "training")
139+
140+
return Custom()
141+
142+
case "flatten":
143+
# very simple summary network for fast training
144+
from bayesflow.networks import SummaryNetwork
145+
from bayesflow.utils.serialization import serializable
146+
import keras
147+
148+
@serializable("test", disable_module_check=True)
149+
class FlattenSummaryNetwork(SummaryNetwork):
150+
def __init__(self, **kwargs):
151+
super().__init__(**kwargs)
152+
self.inner = keras.layers.Flatten()
153+
154+
def call(self, x, **kwargs):
155+
return self.inner(x, training=kwargs.get("stage") == "training")
156+
157+
return FlattenSummaryNetwork()
158+
159+
case "fusion_network":
160+
from bayesflow.networks import FusionNetwork, DeepSet
161+
162+
return FusionNetwork({"a": DeepSet(), "b": keras.layers.Flatten()}, head=keras.layers.Dense(2))
163+
case None:
164+
return None
165+
case _:
166+
raise ValueError(f"Invalid request parameter for summary_network: {request.param}")
167+
168+
39169
@pytest.fixture(params=["sir", "fusion"])
40170
def simulator(request):
41-
if request.param == "sir":
42-
from bayesflow.simulators import SIR
171+
match request.param:
172+
case "sir":
173+
from bayesflow.simulators import SIR
174+
175+
return SIR()
176+
case "lotka_volterra":
177+
from bayesflow.simulators import LotkaVolterra
178+
179+
return LotkaVolterra()
180+
181+
case "two_moons":
182+
from bayesflow.simulators import TwoMoons
183+
184+
return TwoMoons()
185+
case "normal":
186+
from tests.utils.normal_simulator import NormalSimulator
43187

44-
return SIR()
45-
elif request.param == "fusion":
46-
from bayesflow.simulators import Simulator
47-
from bayesflow.types import Shape, Tensor
48-
from bayesflow.utils.decorators import allow_batch_size
49-
import numpy as np
188+
return NormalSimulator()
189+
case "fusion":
190+
from bayesflow.simulators import Simulator
191+
from bayesflow.types import Shape, Tensor
192+
from bayesflow.utils.decorators import allow_batch_size
193+
import numpy as np
50194

51-
class FusionSimulator(Simulator):
52-
@allow_batch_size
53-
def sample(self, batch_shape: Shape, num_observations: int = 4) -> dict[str, Tensor]:
54-
mean = np.random.normal(0.0, 0.1, size=batch_shape + (2,))
55-
noise = np.random.standard_normal(batch_shape + (num_observations, 2))
195+
class FusionSimulator(Simulator):
196+
@allow_batch_size
197+
def sample(self, batch_shape: Shape, num_observations: int = 4) -> dict[str, Tensor]:
198+
mean = np.random.normal(0.0, 0.1, size=batch_shape + (2,))
199+
noise = np.random.standard_normal(batch_shape + (num_observations, 2))
56200

57-
x = mean[:, None] + noise
201+
x = mean[:, None] + noise
58202

59-
return dict(mean=mean, a=x, b=x)
203+
return dict(mean=mean, a=x, b=x)
60204

61-
return FusionSimulator()
205+
return FusionSimulator()
206+
case None:
207+
return None
208+
case _:
209+
raise ValueError(f"Invalid request parameter for simulator: {request.param}")

tests/test_compatibility/test_adapters/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def serializable_fn(x):
3737
.scale("x", by=[-1, 2])
3838
.shift("x", by=2)
3939
.split("key_to_split", into=["split_1", "split_2"])
40-
.standardize(exclude=["t1", "t2", "o1"])
4140
.drop("d1")
4241
.one_hot("o1", 10)
4342
.keep(["x", "y", "z1", "p1", "p2", "s1", "s2", "s3", "t1", "t2", "o1", "split_1", "split_2"])
Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +0,0 @@
1-
import pytest
2-
3-
4-
@pytest.fixture()
5-
def batch_size():
6-
return 8
7-
8-
9-
@pytest.fixture()
10-
def num_samples():
11-
return 100
12-
13-
14-
@pytest.fixture()
15-
def adapter():
16-
import bayesflow as bf
17-
18-
return bf.Adapter.create_default("parameters").rename("observables", "summary_variables")
19-
20-
21-
@pytest.fixture(params=["coupling_flow", "flow_matching"])
22-
def inference_network(request):
23-
if request.param == "coupling_flow":
24-
from bayesflow.networks import CouplingFlow
25-
26-
return CouplingFlow(depth=2)
27-
28-
elif request.param == "flow_matching":
29-
from bayesflow.networks import FlowMatching
30-
31-
return FlowMatching(subnet_kwargs=dict(widths=(32, 32)), use_optimal_transport=False)
32-
33-
34-
@pytest.fixture(params=["time_series_transformer", "fusion_transformer", "time_series_network", "custom"])
35-
def summary_network(request):
36-
if request.param == "time_series_transformer":
37-
from bayesflow.networks import TimeSeriesTransformer
38-
39-
return TimeSeriesTransformer(embed_dims=(8, 8), mlp_widths=(16, 8), mlp_depths=(1, 1))
40-
41-
elif request.param == "fusion_transformer":
42-
from bayesflow.networks import FusionTransformer
43-
44-
return FusionTransformer(
45-
embed_dims=(8, 8), mlp_widths=(8, 16), mlp_depths=(2, 1), template_dim=8, bidirectional=False
46-
)
47-
48-
elif request.param == "time_series_network":
49-
from bayesflow.networks import TimeSeriesNetwork
50-
51-
return TimeSeriesNetwork(filters=4, skip_steps=2)
52-
53-
elif request.param == "custom":
54-
from bayesflow.networks import SummaryNetwork
55-
from bayesflow.utils.serialization import serializable
56-
import keras
57-
58-
@serializable("test", disable_module_check=True)
59-
class Custom(SummaryNetwork):
60-
def __init__(self, **kwargs):
61-
super().__init__(**kwargs)
62-
self.inner = keras.Sequential([keras.layers.LSTM(8), keras.layers.Dense(4)])
63-
64-
def call(self, x, **kwargs):
65-
return self.inner(x, training=kwargs.get("stage") == "training")
66-
67-
return Custom()
68-
69-
elif request.param == "fusion_network":
70-
from bayesflow.networks import FusionNetwork, DeepSet
71-
72-
return FusionNetwork({"a": DeepSet(), "b": keras.layers.Flatten()}, head=keras.layers.Dense(2))
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import pytest
2+
3+
4+
@pytest.fixture
5+
def approximator(adapter, inference_network, summary_network, standardize):
6+
from bayesflow.approximators import ContinuousApproximator
7+
8+
return ContinuousApproximator(
9+
adapter=adapter, inference_network=inference_network, summary_network=summary_network, standardize=standardize
10+
)

tests/test_compatibility/test_approximators/test_continuous_approximator/test_continuous_approximator.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import keras
55

66

7-
@pytest.mark.parametrize("inference_network", ["coupling_flow", "flow_matching"], indirect=True)
7+
@pytest.mark.parametrize("inference_network", ["coupling_flow"], indirect=True)
88
@pytest.mark.parametrize(
9-
"summary_network,simulator,adapter",
9+
"summary_network,simulator,adapter,standardize",
1010
[
11-
["time_series_transformer", "sir", None],
12-
["fusion_transformer", "sir", None],
11+
["deep_set", "sir", "summary", ["summary_variables", "inference_variables"]], # use deep_set for speed
12+
[None, "two_moons", "direct", "all"],
13+
[None, "two_moons", "direct", None],
1314
],
1415
indirect=True,
1516
)
@@ -21,16 +22,9 @@ class TestContinuousApproximator(SaveLoadTest):
2122
}
2223

2324
@pytest.fixture()
24-
def setup(self, filepaths, mode, inference_network, summary_network, simulator, adapter):
25+
def setup(self, filepaths, mode, approximator, adapter, inference_network, summary_network, standardize, simulator):
2526
if mode == "save":
26-
import bayesflow as bf
27-
28-
approximator = bf.approximators.ContinuousApproximator(
29-
adapter=adapter,
30-
inference_network=inference_network,
31-
summary_network=summary_network,
32-
)
33-
approximator.compile("adamw")
27+
approximator.compile("adamw", run_eagerly=False)
3428
approximator.fit(simulator=simulator, epochs=1, batch_size=8, num_batches=2, verbose=0)
3529
keras.saving.save_model(approximator, filepaths["approximator"])
3630

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,6 @@ def adapter():
4848
)
4949

5050

51-
@pytest.fixture
52-
def summary_network():
53-
from bayesflow.networks import DeepSet
54-
55-
return DeepSet(summary_dim=2, depth=1)
56-
57-
5851
@pytest.fixture
5952
def classifier_network():
6053
from bayesflow.networks import MLP
@@ -71,7 +64,7 @@ def approximator(adapter, classifier_network, summary_network, simulator, standa
7164
classifier_network=classifier_network,
7265
adapter=adapter,
7366
summary_network=summary_network,
74-
# standardize=standardize,
67+
standardize=standardize,
7568
)
7669

7770

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import keras
55

66

7+
@pytest.mark.parametrize("summary_network", ["deep_set"], indirect=True)
78
class TestModelComparisonApproximator(SaveLoadTest):
89
filenames = {
910
"approximator": "approximator.keras",
@@ -12,7 +13,7 @@ class TestModelComparisonApproximator(SaveLoadTest):
1213
}
1314

1415
@pytest.fixture()
15-
def setup(self, filepaths, mode, simulator, approximator):
16+
def setup(self, filepaths, mode, simulator, approximator, classifier_network, summary_network):
1617
if mode == "save":
1718
approximator.compile("adamw")
1819
approximator.fit(

0 commit comments

Comments
 (0)