@@ -36,26 +36,174 @@ def data_dir(request, commit, from_commit, tmp_path_factory):
3636 return Path (tmp_path_factory .mktemp ("_compatibility_data" ))
3737
3838
# reduce number of test configurations
@pytest.fixture(params=[None, 3])
def conditions_size(request):
    """Parametrized fixture yielding the conditions size for a test case."""
    return request.param
43+
44+
@pytest.fixture(params=[1, 2])
def summary_dim(request):
    """Parametrized fixture yielding the summary dimension for a test case."""
    return request.param
48+
49+
@pytest.fixture(params=[4])
def feature_size(request):
    """Parametrized fixture yielding the feature size for a test case."""
    return request.param
53+
54+
# Generic fixtures for use as input to the tested classes.
# The classes to test are constructed in the respective subdirectories, to allow for more thorough configuration.
@pytest.fixture(params=[None, "all"])
def standardize(request):
    """Parametrized fixture yielding the standardization setting (``None`` or ``"all"``)."""
    return request.param
60+
61+
@pytest.fixture()
def adapter(request):
    """Build a ``bayesflow.Adapter`` variant selected by ``request.param``.

    NOTE(review): this fixture declares no ``params`` yet reads
    ``request.param`` — it only works under indirect parametrization;
    confirm against the tests that use it.

    Raises:
        ValueError: if the parameter does not name a known adapter variant.
    """
    import bayesflow as bf

    choice = request.param
    if choice == "summary":
        return bf.Adapter.create_default("parameters").rename("observables", "summary_variables")
    if choice == "direct":
        return bf.Adapter.create_default("parameters").rename("observables", "direct_conditions")
    if choice == "default":
        return bf.Adapter.create_default("parameters")
    if choice == "empty":
        return bf.Adapter()
    if choice is None:
        return None
    raise ValueError(f"Invalid request parameter for adapter: {choice}")
79+
80+
@pytest.fixture(params=["coupling_flow", "flow_matching"])
def inference_network(request):
    """Construct a small inference network of the requested architecture.

    Raises:
        ValueError: if the parameter does not name a known inference network.
    """
    kind = request.param
    if kind == "coupling_flow":
        from bayesflow.networks import CouplingFlow

        return CouplingFlow(depth=2)
    if kind == "flow_matching":
        from bayesflow.networks import FlowMatching

        return FlowMatching(subnet_kwargs=dict(widths=(32, 32)), use_optimal_transport=False)
    if kind is None:
        return None
    raise ValueError(f"Invalid request parameter for inference_network: {kind}")
99+
100+
101+ @pytest .fixture (params = ["time_series_transformer" , "fusion_transformer" , "time_series_network" , "custom" ])
102+ def summary_network (request ):
103+ match request .param :
104+ case "time_series_transformer" :
105+ from bayesflow .networks import TimeSeriesTransformer
106+
107+ return TimeSeriesTransformer (embed_dims = (8 , 8 ), mlp_widths = (16 , 8 ), mlp_depths = (1 , 1 ))
108+
109+ case "fusion_transformer" :
110+ from bayesflow .networks import FusionTransformer
111+
112+ return FusionTransformer (
113+ embed_dims = (8 , 8 ), mlp_widths = (8 , 16 ), mlp_depths = (2 , 1 ), template_dim = 8 , bidirectional = False
114+ )
115+
116+ case "time_series_network" :
117+ from bayesflow .networks import TimeSeriesNetwork
118+
119+ return TimeSeriesNetwork (filters = 4 , skip_steps = 2 )
120+
121+ case "deep_set" :
122+ from bayesflow .networks import DeepSet
123+
124+ return DeepSet (summary_dim = 2 , depth = 1 )
125+
126+ case "custom" :
127+ from bayesflow .networks import SummaryNetwork
128+ from bayesflow .utils .serialization import serializable
129+ import keras
130+
131+ @serializable ("test" , disable_module_check = True )
132+ class Custom (SummaryNetwork ):
133+ def __init__ (self , ** kwargs ):
134+ super ().__init__ (** kwargs )
135+ self .inner = keras .Sequential ([keras .layers .LSTM (8 ), keras .layers .Dense (4 )])
136+
137+ def call (self , x , ** kwargs ):
138+ return self .inner (x , training = kwargs .get ("stage" ) == "training" )
139+
140+ return Custom ()
141+
142+ case "flatten" :
143+ # very simple summary network for fast training
144+ from bayesflow .networks import SummaryNetwork
145+ from bayesflow .utils .serialization import serializable
146+ import keras
147+
148+ @serializable ("test" , disable_module_check = True )
149+ class FlattenSummaryNetwork (SummaryNetwork ):
150+ def __init__ (self , ** kwargs ):
151+ super ().__init__ (** kwargs )
152+ self .inner = keras .layers .Flatten ()
153+
154+ def call (self , x , ** kwargs ):
155+ return self .inner (x , training = kwargs .get ("stage" ) == "training" )
156+
157+ return FlattenSummaryNetwork ()
158+
159+ case "fusion_network" :
160+ from bayesflow .networks import FusionNetwork , DeepSet
161+
162+ return FusionNetwork ({"a" : DeepSet (), "b" : keras .layers .Flatten ()}, head = keras .layers .Dense (2 ))
163+ case None :
164+ return None
165+ case _:
166+ raise ValueError (f"Invalid request parameter for summary_network: { request .param } " )
167+
168+
@pytest.fixture(params=["sir", "fusion"])
def simulator(request):
    """Instantiate the simulator named by ``request.param``.

    Cases beyond the declared ``params`` are reachable via indirect
    parametrization.

    Raises:
        ValueError: if the parameter does not name a known simulator.
    """
    choice = request.param
    if choice == "sir":
        from bayesflow.simulators import SIR

        return SIR()
    if choice == "lotka_volterra":
        from bayesflow.simulators import LotkaVolterra

        return LotkaVolterra()
    if choice == "two_moons":
        from bayesflow.simulators import TwoMoons

        return TwoMoons()
    if choice == "normal":
        from tests.utils.normal_simulator import NormalSimulator

        return NormalSimulator()
    if choice == "fusion":
        from bayesflow.simulators import Simulator
        from bayesflow.types import Shape, Tensor
        from bayesflow.utils.decorators import allow_batch_size
        import numpy as np

        class FusionSimulator(Simulator):
            @allow_batch_size
            def sample(self, batch_shape: Shape, num_observations: int = 4) -> dict[str, Tensor]:
                # Shared per-batch mean with independent per-observation noise;
                # both modalities "a" and "b" receive the same observations.
                mean = np.random.normal(0.0, 0.1, size=batch_shape + (2,))
                noise = np.random.standard_normal(batch_shape + (num_observations, 2))

                x = mean[:, None] + noise

                return dict(mean=mean, a=x, b=x)

        return FusionSimulator()
    if choice is None:
        return None
    raise ValueError(f"Invalid request parameter for simulator: {choice}")
0 commit comments