@@ -61,12 +61,12 @@ def get_config(self):
6161
6262
6363@pytest .fixture ()
64- def diffusion_model_edm_F_subnet_concatenate ():
64+ def diffusion_model_edm_F_subnet_separate_inputs ():
6565 from bayesflow .networks import DiffusionModel
6666
6767 return DiffusionModel (
6868 subnet = ConcatenateMLP ([8 , 8 ]),
69- integrate_kwargs = {"method" : "rk45" , "steps" : 250 },
69+ integrate_kwargs = {"method" : "rk45" , "steps" : 4 },
7070 noise_schedule = "edm" ,
7171 prediction_type = "F" ,
7272 concatenate_subnet_input = False ,
@@ -144,11 +144,11 @@ def flow_matching():
144144
145145
146146@pytest .fixture ()
147- def flow_matching_subnet_concatenate ():
147+ def flow_matching_subnet_separate_inputs ():
148148 from bayesflow .networks import FlowMatching
149149
150150 return FlowMatching (
151- subnet = ConcatenateMLP ([8 , 8 ]), integrate_kwargs = {"method" : "rk45" , "steps" : 100 }, concatenate_subnet_input = False
151+ subnet = ConcatenateMLP ([8 , 8 ]), integrate_kwargs = {"method" : "rk45" , "steps" : 4 }, concatenate_subnet_input = False
152152 )
153153
154154
@@ -160,10 +160,10 @@ def consistency_model():
160160
161161
162162@pytest .fixture ()
163- def consistency_model_subnet_concatenate ():
163+ def consistency_model_subnet_separate_inputs ():
164164 from bayesflow .networks import ConsistencyModel
165165
166- return ConsistencyModel (total_steps = 100 , subnet = ConcatenateMLP ([8 , 8 ]), concatenate_subnet_input = False )
166+ return ConsistencyModel (total_steps = 4 , subnet = ConcatenateMLP ([8 , 8 ]), concatenate_subnet_input = False )
167167
168168
169169@pytest .fixture ()
@@ -230,12 +230,9 @@ def typical_point_inference_network_subnet():
230230 "affine_coupling_flow" ,
231231 "spline_coupling_flow" ,
232232 "flow_matching" ,
233- pytest .param ("flow_matching_subnet_concatenate" ),
234233 "free_form_flow" ,
235234 "consistency_model" ,
236- pytest .param ("consistency_model_subnet_concatenate" ),
237235 pytest .param ("diffusion_model_edm_F" ),
238- pytest .param ("diffusion_model_edm_F_subnet_concatenate" ),
239236 pytest .param ("diffusion_model_edm_noise" , marks = pytest .mark .slow ),
240237 pytest .param ("diffusion_model_cosine_velocity" , marks = pytest .mark .slow ),
241238 pytest .param ("diffusion_model_cosine_F" , marks = pytest .mark .slow ),
@@ -263,16 +260,12 @@ def inference_network_subnet(request):
263260
264261@pytest .fixture (
265262 params = [
266- pytest .param ("diffusion_model_edm_F_subnet_concatenate" ),
267263 "affine_coupling_flow" ,
268264 "spline_coupling_flow" ,
269265 "flow_matching" ,
270- pytest .param ("flow_matching_subnet_concatenate" ),
271266 "free_form_flow" ,
272267 "consistency_model" ,
273- pytest .param ("consistency_model_subnet_concatenate" ),
274268 pytest .param ("diffusion_model_edm_F" ),
275- pytest .param ("diffusion_model_edm_F_subnet_concatenate" ),
276269 pytest .param (
277270 "diffusion_model_edm_noise" ,
278271 marks = [
@@ -309,6 +302,18 @@ def generative_inference_network(request):
309302 return request .getfixturevalue (request .param )
310303
311304
305+ @pytest .fixture (
306+ params = [
307+ pytest .param ("flow_matching_subnet_separate_inputs" ),
308+ pytest .param ("consistency_model_subnet_separate_inputs" ),
309+ pytest .param ("diffusion_model_edm_F_subnet_separate_inputs" ),
310+ ],
311+ scope = "function" ,
312+ )
313+ def inference_network_subnet_separate_inputs (request ):
314+ return request .getfixturevalue (request .param )
315+
316+
312317@pytest .fixture (scope = "function" )
313318def time_series_network (summary_dim ):
314319 from bayesflow .networks import TimeSeriesNetwork
0 commit comments