44
55
66@pytest .fixture ()
7- def diffusion_model ():
7+ def diffusion_model_edm_F ():
88 from bayesflow .experimental import DiffusionModel
99
1010 return DiffusionModel (
1111 subnet_kwargs = {"widths" : [64 , 64 ]},
1212 integrate_kwargs = {"method" : "rk45" , "steps" : 100 },
13+ noise_schedule = "edm" ,
14+ prediction_type = "F" ,
1315 )
1416
1517
1618@pytest .fixture ()
17- def diffusion_model_subnet ( subnet ):
19+ def diffusion_model_edm_velocity ( ):
1820 from bayesflow .experimental import DiffusionModel
1921
20- return DiffusionModel (subnet = subnet )
22+ return DiffusionModel (
23+ subnet_kwargs = {"widths" : [64 , 64 ]},
24+ integrate_kwargs = {"method" : "rk45" , "steps" : 100 },
25+ noise_schedule = "edm" ,
26+ prediction_type = "velocity" ,
27+ )
28+
29+
30+ @pytest .fixture ()
31+ def diffusion_model_edm_noise ():
32+ from bayesflow .experimental import DiffusionModel
33+
34+ return DiffusionModel (
35+ subnet_kwargs = {"widths" : [64 , 64 ]},
36+ integrate_kwargs = {"method" : "rk45" , "steps" : 100 },
37+ noise_schedule = "edm" ,
38+ prediction_type = "noise" ,
39+ )
40+
41+
42+ @pytest .fixture ()
43+ def diffusion_model_cosine_F ():
44+ from bayesflow .experimental import DiffusionModel
45+
46+ return DiffusionModel (
47+ subnet_kwargs = {"widths" : [64 , 64 ]},
48+ integrate_kwargs = {"method" : "rk45" , "steps" : 100 },
49+ noise_schedule = "cosine" ,
50+ prediction_type = "F" ,
51+ )
52+
53+
54+ @pytest .fixture ()
55+ def diffusion_model_cosine_velocity ():
56+ from bayesflow .experimental import DiffusionModel
57+
58+ return DiffusionModel (
59+ subnet_kwargs = {"widths" : [64 , 64 ]},
60+ integrate_kwargs = {"method" : "rk45" , "steps" : 100 },
61+ noise_schedule = "cosine" ,
62+ prediction_type = "velocity" ,
63+ )
64+
65+
66+ @pytest .fixture ()
67+ def diffusion_model_cosine_noise ():
68+ from bayesflow .experimental import DiffusionModel
69+
70+ return DiffusionModel (
71+ subnet_kwargs = {"widths" : [64 , 64 ]},
72+ integrate_kwargs = {"method" : "rk45" , "steps" : 100 },
73+ noise_schedule = "cosine" ,
74+ prediction_type = "noise" ,
75+ )
2176
2277
2378@pytest .fixture ()
@@ -101,9 +156,14 @@ def typical_point_inference_network_subnet():
101156 "affine_coupling_flow" ,
102157 "spline_coupling_flow" ,
103158 "flow_matching" ,
104- "diffusion_model" ,
105159 "free_form_flow" ,
106160 "consistency_model" ,
161+ pytest .param ("diffusion_model_edm_F" , marks = pytest .mark .diffusion_model ),
162+ pytest .param ("diffusion_model_edm_noise" , marks = [pytest .mark .slow , pytest .mark .diffusion_model ]),
163+ pytest .param ("diffusion_model_cosine_velocity" , marks = [pytest .mark .slow , pytest .mark .diffusion_model ]),
164+ pytest .param ("diffusion_model_cosine_F" , marks = [pytest .mark .slow , pytest .mark .diffusion_model ]),
165+ pytest .param ("diffusion_model_cosine_noise" , marks = [pytest .mark .slow , pytest .mark .diffusion_model ]),
166+ pytest .param ("diffusion_model_cosine_velocity" , marks = [pytest .mark .slow , pytest .mark .diffusion_model ]),
107167 ],
108168 scope = "function" ,
109169)
@@ -116,7 +176,6 @@ def inference_network(request):
116176 "typical_point_inference_network_subnet" ,
117177 "coupling_flow_subnet" ,
118178 "flow_matching_subnet" ,
119- "diffusion_model_subnet" ,
120179 "free_form_flow_subnet" ,
121180 ],
122181 scope = "function" ,
@@ -130,9 +189,28 @@ def inference_network_subnet(request):
130189 "affine_coupling_flow" ,
131190 "spline_coupling_flow" ,
132191 "flow_matching" ,
133- "diffusion_model" ,
134192 "free_form_flow" ,
135193 "consistency_model" ,
194+ pytest .param ("diffusion_model_edm_F" , marks = pytest .mark .diffusion_model ),
195+ pytest .param (
196+ "diffusion_model_edm_noise" ,
197+ marks = [
198+ pytest .mark .slow ,
199+ pytest .mark .diffusion_model ,
200+ pytest .mark .skip ("noise predicition not testable without prior training for numerical reasons." ),
201+ ],
202+ ),
203+ pytest .param ("diffusion_model_cosine_velocity" , marks = [pytest .mark .slow , pytest .mark .diffusion_model ]),
204+ pytest .param ("diffusion_model_cosine_F" , marks = [pytest .mark .slow , pytest .mark .diffusion_model ]),
205+ pytest .param (
206+ "diffusion_model_cosine_noise" ,
207+ marks = [
208+ pytest .mark .slow ,
209+ pytest .mark .diffusion_model ,
210+ pytest .mark .skip ("noise predicition not testable without prior training for numerical reasons." ),
211+ ],
212+ ),
213+ pytest .param ("diffusion_model_cosine_velocity" , marks = [pytest .mark .slow , pytest .mark .diffusion_model ]),
136214 ],
137215 scope = "function" ,
138216)
0 commit comments