11import pytest
22
3+ from bayesflow .networks .sequential import Sequential
34from bayesflow .networks import MLP
5+ from bayesflow .utils .tensor_utils import concatenate_valid
46
57
68@pytest .fixture ()
@@ -15,6 +17,29 @@ def diffusion_model_edm_F():
1517 )
1618
1719
20+ class ConcatenateMLP (Sequential ):
21+ def __init__ (self , widths ):
22+ super ().__init__ ()
23+ self .mlp = MLP (widths )
24+
25+ def call (self , x , t , conditions = None , training = False ):
26+ con = concatenate_valid ([x , t , conditions ], axis = - 1 )
27+ return self .mlp (con )
28+
29+
30+ @pytest .fixture ()
31+ def diffusion_model_edm_F_subnet_concatenate ():
32+ from bayesflow .networks import DiffusionModel
33+
34+ return DiffusionModel (
35+ subnet = ConcatenateMLP ([8 , 8 ]),
36+ integrate_kwargs = {"method" : "rk45" , "steps" : 250 },
37+ noise_schedule = "edm" ,
38+ prediction_type = "F" ,
39+ concatenate_subnet_input = False ,
40+ )
41+
42+
1843@pytest .fixture ()
1944def diffusion_model_edm_velocity ():
2045 from bayesflow .networks import DiffusionModel
@@ -85,13 +110,29 @@ def flow_matching():
85110 )
86111
87112
113+ @pytest .fixture ()
114+ def flow_matching_subnet_concatenate ():
115+ from bayesflow .networks import FlowMatching
116+
117+ return FlowMatching (
118+ subnet = ConcatenateMLP ([8 , 8 ]), integrate_kwargs = {"method" : "rk45" , "steps" : 100 }, concatenate_subnet_input = False
119+ )
120+
121+
88122@pytest .fixture ()
89123def consistency_model ():
90124 from bayesflow .networks import ConsistencyModel
91125
92126 return ConsistencyModel (total_steps = 100 , subnet = MLP ([8 , 8 ]))
93127
94128
129+ @pytest .fixture ()
130+ def consistency_model_subnet_concatenate ():
131+ from bayesflow .networks import ConsistencyModel
132+
133+ return ConsistencyModel (total_steps = 100 , subnet = ConcatenateMLP ([8 , 8 ]), concatenate_subnet_input = False )
134+
135+
95136@pytest .fixture ()
96137def affine_coupling_flow ():
97138 from bayesflow .networks import CouplingFlow
@@ -189,14 +230,17 @@ def inference_network_subnet(request):
189230 "affine_coupling_flow" ,
190231 "spline_coupling_flow" ,
191232 "flow_matching" ,
233+ pytest .param ("flow_matching_subnet_concatenate" ),
192234 "free_form_flow" ,
193235 "consistency_model" ,
236+ pytest .param ("consistency_model_subnet_concatenate" ),
194237 pytest .param ("diffusion_model_edm_F" ),
238+ pytest .param ("diffusion_model_edm_F_subnet_concatenate" ),
195239 pytest .param (
196240 "diffusion_model_edm_noise" ,
197241 marks = [
198242 pytest .mark .slow ,
199- pytest .mark .skip ("noise predicition not testable without prior training for numerical reasons." ),
243+ pytest .mark .skip ("noise prediction not testable without prior training for numerical reasons." ),
200244 ],
201245 ),
202246 pytest .param ("diffusion_model_cosine_velocity" , marks = pytest .mark .slow ),
@@ -211,7 +255,7 @@ def inference_network_subnet(request):
211255 "diffusion_model_cosine_noise" ,
212256 marks = [
213257 pytest .mark .slow ,
214- pytest .mark .skip ("noise predicition not testable without prior training for numerical reasons." ),
258+ pytest .mark .skip ("noise prediction not testable without prior training for numerical reasons." ),
215259 ],
216260 ),
217261 pytest .param (
0 commit comments