1+ import pytest
12import keras
23from tests .utils import check_combination_simulator_adapter
34
@@ -16,3 +17,92 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter):
1617 samples = approximator .sample (num_samples = 2 , conditions = data )
1718
1819 assert isinstance (samples , dict )
20+
21+
22+ @pytest .mark .parametrize ("inference_network_type" , ["flow_matching" , "diffusion_model" ])
23+ @pytest .mark .parametrize ("summary_network_type" , ["none" , "deep_set" , "set_transformer" , "time_series" ])
24+ @pytest .mark .parametrize ("method" , ["euler" , "rk45" , "euler_maruyama" ])
25+ def test_approximator_sample_with_integration_methods (
26+ inference_network_type , summary_network_type , method , simulator , adapter
27+ ):
28+ """Test approximator sampling with different integration methods and summary networks.
29+
30+ Tests flow matching and diffusion models with different ODE/SDE solvers:
31+ - euler, rk45: Available for both flow matching and diffusion models
32+ - euler_maruyama: Only for diffusion models (stochastic)
33+
34+ Also tests with different summary network types.
35+ """
36+ batch_size = 8 # Use smaller batch size for faster tests
37+ check_combination_simulator_adapter (simulator , adapter )
38+
39+ # Skip euler_maruyama for flow matching (deterministic model)
40+ if inference_network_type == "flow_matching" and method == "euler_maruyama" :
41+ pytest .skip ("euler_maruyama is only available for diffusion models" )
42+
43+ # Create inference network based on type
44+ if inference_network_type == "flow_matching" :
45+ from bayesflow .networks import FlowMatching , MLP
46+
47+ inference_network = FlowMatching (
48+ subnet = MLP (widths = [32 , 32 ]),
49+ integrate_kwargs = {"steps" : 10 }, # Use fewer steps for faster tests
50+ )
51+ elif inference_network_type == "diffusion_model" :
52+ from bayesflow .networks import DiffusionModel , MLP
53+
54+ inference_network = DiffusionModel (
55+ subnet = MLP (widths = [32 , 32 ]),
56+ integrate_kwargs = {"steps" : 10 }, # Use fewer steps for faster tests
57+ )
58+ else :
59+ pytest .skip (f"Unsupported inference network type: { inference_network_type } " )
60+
61+ # Create summary network based on type
62+ summary_network = None
63+ if summary_network_type != "none" :
64+ if summary_network_type == "deep_set" :
65+ from bayesflow .networks import DeepSet , MLP
66+
67+ summary_network = DeepSet (subnet = MLP (widths = [16 , 16 ]))
68+ elif summary_network_type == "set_transformer" :
69+ from bayesflow .networks import SetTransformer
70+
71+ summary_network = SetTransformer (embed_dims = [16 , 16 ], mlp_widths = [16 , 16 ])
72+ elif summary_network_type == "time_series" :
73+ from bayesflow .networks import TimeSeriesNetwork
74+
75+ summary_network = TimeSeriesNetwork (subnet_kwargs = {"widths" : [16 , 16 ]}, cell_type = "lstm" )
76+ else :
77+ pytest .skip (f"Unsupported summary network type: { summary_network_type } " )
78+
79+ # Update adapter to include summary variables if summary network is present
80+ from bayesflow import ContinuousApproximator
81+
82+ adapter = ContinuousApproximator .build_adapter (
83+ inference_variables = ["mean" , "std" ],
84+ summary_variables = ["x" ], # Use x as summary variable for testing
85+ )
86+
87+ # Create approximator
88+ from bayesflow import ContinuousApproximator
89+
90+ approximator = ContinuousApproximator (
91+ adapter = adapter , inference_network = inference_network , summary_network = summary_network
92+ )
93+
94+ # Generate test data
95+ num_batches = 2 # Use fewer batches for faster tests
96+ data = simulator .sample ((num_batches * batch_size ,))
97+
98+ # Build approximator
99+ batch = adapter (data )
100+ batch = keras .tree .map_structure (keras .ops .convert_to_tensor , batch )
101+ batch_shapes = keras .tree .map_structure (keras .ops .shape , batch )
102+ approximator .build (batch_shapes )
103+
104+ # Test sampling with the specified method
105+ samples = approximator .sample (num_samples = 2 , conditions = data , method = method )
106+
107+ # Verify results
108+ assert isinstance (samples , dict )
0 commit comments