Skip to content

Commit 703ac2d

Browse files
committed
add test
1 parent a1e6ef0 commit 703ac2d

File tree

1 file changed

+46
-2
lines changed

1 file changed

+46
-2
lines changed

tests/test_networks/conftest.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pytest
22

3+
from bayesflow.networks.sequential import Sequential
34
from bayesflow.networks import MLP
5+
from bayesflow.utils.tensor_utils import concatenate_valid
46

57

68
@pytest.fixture()
@@ -15,6 +17,29 @@ def diffusion_model_edm_F():
1517
)
1618

1719

20+
class ConcatenateMLP(Sequential):
21+
def __init__(self, widths):
22+
super().__init__()
23+
self.mlp = MLP(widths)
24+
25+
def call(self, x, t, conditions=None, training=False):
26+
con = concatenate_valid([x, t, conditions], axis=-1)
27+
return self.mlp(con)
28+
29+
30+
@pytest.fixture()
31+
def diffusion_model_edm_F_subnet_concatenate():
32+
from bayesflow.networks import DiffusionModel
33+
34+
return DiffusionModel(
35+
subnet=ConcatenateMLP([8, 8]),
36+
integrate_kwargs={"method": "rk45", "steps": 250},
37+
noise_schedule="edm",
38+
prediction_type="F",
39+
concatenate_subnet_input=False,
40+
)
41+
42+
1843
@pytest.fixture()
1944
def diffusion_model_edm_velocity():
2045
from bayesflow.networks import DiffusionModel
@@ -85,13 +110,29 @@ def flow_matching():
85110
)
86111

87112

113+
@pytest.fixture()
114+
def flow_matching_subnet_concatenate():
115+
from bayesflow.networks import FlowMatching
116+
117+
return FlowMatching(
118+
subnet=ConcatenateMLP([8, 8]), integrate_kwargs={"method": "rk45", "steps": 100}, concatenate_subnet_input=False
119+
)
120+
121+
88122
@pytest.fixture()
89123
def consistency_model():
90124
from bayesflow.networks import ConsistencyModel
91125

92126
return ConsistencyModel(total_steps=100, subnet=MLP([8, 8]))
93127

94128

129+
@pytest.fixture()
130+
def consistency_model_subnet_concatenate():
131+
from bayesflow.networks import ConsistencyModel
132+
133+
return ConsistencyModel(total_steps=100, subnet=ConcatenateMLP([8, 8]), concatenate_subnet_input=False)
134+
135+
95136
@pytest.fixture()
96137
def affine_coupling_flow():
97138
from bayesflow.networks import CouplingFlow
@@ -189,14 +230,17 @@ def inference_network_subnet(request):
189230
"affine_coupling_flow",
190231
"spline_coupling_flow",
191232
"flow_matching",
233+
pytest.param("flow_matching_subnet_concatenate"),
192234
"free_form_flow",
193235
"consistency_model",
236+
pytest.param("consistency_model_subnet_concatenate"),
194237
pytest.param("diffusion_model_edm_F"),
238+
pytest.param("diffusion_model_edm_F_subnet_concatenate"),
195239
pytest.param(
196240
"diffusion_model_edm_noise",
197241
marks=[
198242
pytest.mark.slow,
199-
pytest.mark.skip("noise predicition not testable without prior training for numerical reasons."),
243+
pytest.mark.skip("noise prediction not testable without prior training for numerical reasons."),
200244
],
201245
),
202246
pytest.param("diffusion_model_cosine_velocity", marks=pytest.mark.slow),
@@ -211,7 +255,7 @@ def inference_network_subnet(request):
211255
"diffusion_model_cosine_noise",
212256
marks=[
213257
pytest.mark.slow,
214-
pytest.mark.skip("noise predicition not testable without prior training for numerical reasons."),
258+
pytest.mark.skip("noise prediction not testable without prior training for numerical reasons."),
215259
],
216260
),
217261
pytest.param(

0 commit comments

Comments
 (0)