Skip to content

Commit 8402a3f

Browse files
committed
add diffusion models to inference network tests
- as we have six combinations, and the inference network test are quite compute heavy, I have marked the non-default ones as "slow" - the "noise" prediction setup produces too extreme values (due to the combination of the transformations and the outputs of an untrained network) to be tested with our current test suite. I have double checked the formulas, and also tested on two moons. The results did not indicate problems, so I decided to skip those tests for now. - for convenience, I added a marker so that the diffusion tests can be selected and be run together. We can remove it if it is not wanted here.
1 parent e380f5e commit 8402a3f

File tree

1 file changed

+84
-6
lines changed

1 file changed

+84
-6
lines changed

tests/test_networks/conftest.py

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,75 @@
44

55

66
@pytest.fixture()
7-
def diffusion_model():
7+
def diffusion_model_edm_F():
88
from bayesflow.experimental import DiffusionModel
99

1010
return DiffusionModel(
1111
subnet_kwargs={"widths": [64, 64]},
1212
integrate_kwargs={"method": "rk45", "steps": 100},
13+
noise_schedule="edm",
14+
prediction_type="F",
1315
)
1416

1517

1618
@pytest.fixture()
17-
def diffusion_model_subnet(subnet):
19+
def diffusion_model_edm_velocity():
1820
from bayesflow.experimental import DiffusionModel
1921

20-
return DiffusionModel(subnet=subnet)
22+
return DiffusionModel(
23+
subnet_kwargs={"widths": [64, 64]},
24+
integrate_kwargs={"method": "rk45", "steps": 100},
25+
noise_schedule="edm",
26+
prediction_type="velocity",
27+
)
28+
29+
30+
@pytest.fixture()
31+
def diffusion_model_edm_noise():
32+
from bayesflow.experimental import DiffusionModel
33+
34+
return DiffusionModel(
35+
subnet_kwargs={"widths": [64, 64]},
36+
integrate_kwargs={"method": "rk45", "steps": 100},
37+
noise_schedule="edm",
38+
prediction_type="noise",
39+
)
40+
41+
42+
@pytest.fixture()
43+
def diffusion_model_cosine_F():
44+
from bayesflow.experimental import DiffusionModel
45+
46+
return DiffusionModel(
47+
subnet_kwargs={"widths": [64, 64]},
48+
integrate_kwargs={"method": "rk45", "steps": 100},
49+
noise_schedule="cosine",
50+
prediction_type="F",
51+
)
52+
53+
54+
@pytest.fixture()
55+
def diffusion_model_cosine_velocity():
56+
from bayesflow.experimental import DiffusionModel
57+
58+
return DiffusionModel(
59+
subnet_kwargs={"widths": [64, 64]},
60+
integrate_kwargs={"method": "rk45", "steps": 100},
61+
noise_schedule="cosine",
62+
prediction_type="velocity",
63+
)
64+
65+
66+
@pytest.fixture()
67+
def diffusion_model_cosine_noise():
68+
from bayesflow.experimental import DiffusionModel
69+
70+
return DiffusionModel(
71+
subnet_kwargs={"widths": [64, 64]},
72+
integrate_kwargs={"method": "rk45", "steps": 100},
73+
noise_schedule="cosine",
74+
prediction_type="noise",
75+
)
2176

2277

2378
@pytest.fixture()
@@ -101,9 +156,14 @@ def typical_point_inference_network_subnet():
101156
"affine_coupling_flow",
102157
"spline_coupling_flow",
103158
"flow_matching",
104-
"diffusion_model",
105159
"free_form_flow",
106160
"consistency_model",
161+
pytest.param("diffusion_model_edm_F", marks=pytest.mark.diffusion_model),
162+
pytest.param("diffusion_model_edm_noise", marks=[pytest.mark.slow, pytest.mark.diffusion_model]),
163+
pytest.param("diffusion_model_cosine_velocity", marks=[pytest.mark.slow, pytest.mark.diffusion_model]),
164+
pytest.param("diffusion_model_cosine_F", marks=[pytest.mark.slow, pytest.mark.diffusion_model]),
165+
pytest.param("diffusion_model_cosine_noise", marks=[pytest.mark.slow, pytest.mark.diffusion_model]),
166+
pytest.param("diffusion_model_cosine_velocity", marks=[pytest.mark.slow, pytest.mark.diffusion_model]),
107167
],
108168
scope="function",
109169
)
@@ -116,7 +176,6 @@ def inference_network(request):
116176
"typical_point_inference_network_subnet",
117177
"coupling_flow_subnet",
118178
"flow_matching_subnet",
119-
"diffusion_model_subnet",
120179
"free_form_flow_subnet",
121180
],
122181
scope="function",
@@ -130,9 +189,28 @@ def inference_network_subnet(request):
130189
"affine_coupling_flow",
131190
"spline_coupling_flow",
132191
"flow_matching",
133-
"diffusion_model",
134192
"free_form_flow",
135193
"consistency_model",
194+
pytest.param("diffusion_model_edm_F", marks=pytest.mark.diffusion_model),
195+
pytest.param(
196+
"diffusion_model_edm_noise",
197+
marks=[
198+
pytest.mark.slow,
199+
pytest.mark.diffusion_model,
200+
pytest.mark.skip("noise predicition not testable without prior training for numerical reasons."),
201+
],
202+
),
203+
pytest.param("diffusion_model_cosine_velocity", marks=[pytest.mark.slow, pytest.mark.diffusion_model]),
204+
pytest.param("diffusion_model_cosine_F", marks=[pytest.mark.slow, pytest.mark.diffusion_model]),
205+
pytest.param(
206+
"diffusion_model_cosine_noise",
207+
marks=[
208+
pytest.mark.slow,
209+
pytest.mark.diffusion_model,
210+
pytest.mark.skip("noise predicition not testable without prior training for numerical reasons."),
211+
],
212+
),
213+
pytest.param("diffusion_model_cosine_velocity", marks=[pytest.mark.slow, pytest.mark.diffusion_model]),
136214
],
137215
scope="function",
138216
)

0 commit comments

Comments
 (0)