diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py index 5bd7c9dff..eddb8cf09 100644 --- a/bayesflow/networks/transformers/mab.py +++ b/bayesflow/networks/transformers/mab.py @@ -3,7 +3,7 @@ from bayesflow.networks import MLP from bayesflow.types import Tensor -from bayesflow.utils import layer_kwargs +from bayesflow.utils import layer_kwargs, filter_kwargs from bayesflow.utils.decorators import sanitize_input_shape from bayesflow.utils.serialization import serializable @@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) - """ h = self.input_projector(seq_x) + self.attention( - query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs + query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call) ) if self.ln_pre is not None: h = self.ln_pre(h, training=training) diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py index d0d748067..bd8290272 100644 --- a/bayesflow/networks/transformers/set_transformer.py +++ b/bayesflow/networks/transformers/set_transformer.py @@ -147,7 +147,7 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor: out : Tensor Output of shape (batch_size, set_size, output_dim) """ - summary = self.attention_blocks(input_set, training=training, **kwargs) + summary = self.attention_blocks(input_set, training=training) summary = self.pooling_by_attention(summary, training=training, **kwargs) summary = self.output_projector(summary) return summary diff --git a/tests/test_approximators/test_sample.py b/tests/test_approximators/test_sample.py index c62ffc581..e76b72a40 100644 --- a/tests/test_approximators/test_sample.py +++ b/tests/test_approximators/test_sample.py @@ -1,3 +1,4 @@ +import pytest import keras from tests.utils import check_combination_simulator_adapter @@ -16,3 +17,92 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter): samples = approximator.sample(num_samples=2, conditions=data) assert isinstance(samples, dict) + + +@pytest.mark.parametrize("inference_network_type", ["flow_matching", "diffusion_model"]) +@pytest.mark.parametrize("summary_network_type", ["none", "deep_set", "set_transformer", "time_series"]) +@pytest.mark.parametrize("method", ["euler", "rk45", "euler_maruyama"]) +def test_approximator_sample_with_integration_methods( + inference_network_type, summary_network_type, method, simulator, adapter +): + """Test approximator sampling with different integration methods and summary networks. + + Tests flow matching and diffusion models with different ODE/SDE solvers: + - euler, rk45: Available for both flow matching and diffusion models + - euler_maruyama: Only for diffusion models (stochastic) + + Also tests with different summary network types. + """ + batch_size = 8 # Use smaller batch size for faster tests + check_combination_simulator_adapter(simulator, adapter) + + # Skip euler_maruyama for flow matching (deterministic model) + if inference_network_type == "flow_matching" and method == "euler_maruyama": + pytest.skip("euler_maruyama is only available for diffusion models") + + # Create inference network based on type + if inference_network_type == "flow_matching": + from bayesflow.networks import FlowMatching, MLP + + inference_network = FlowMatching( + subnet=MLP(widths=[32, 32]), + integrate_kwargs={"steps": 10}, # Use fewer steps for faster tests + ) + elif inference_network_type == "diffusion_model": + from bayesflow.networks import DiffusionModel, MLP + + inference_network = DiffusionModel( + subnet=MLP(widths=[32, 32]), + integrate_kwargs={"steps": 10}, # Use fewer steps for faster tests + ) + else: + pytest.skip(f"Unsupported inference network type: {inference_network_type}") + + # Create summary network based on type + summary_network = None + if summary_network_type != "none": + if summary_network_type == "deep_set": + from bayesflow.networks import DeepSet, MLP + + summary_network = DeepSet(subnet=MLP(widths=[16, 16])) + elif summary_network_type == "set_transformer": + from bayesflow.networks import SetTransformer + + summary_network = SetTransformer(embed_dims=[16, 16], mlp_widths=[16, 16]) + elif summary_network_type == "time_series": + from bayesflow.networks import TimeSeriesNetwork + + summary_network = TimeSeriesNetwork(subnet_kwargs={"widths": [16, 16]}, cell_type="lstm") + else: + pytest.skip(f"Unsupported summary network type: {summary_network_type}") + + # Update adapter to include summary variables if summary network is present + from bayesflow import ContinuousApproximator + + adapter = ContinuousApproximator.build_adapter( + inference_variables=["mean", "std"], + summary_variables=["x"], # Use x as summary variable for testing + ) + + # Create approximator + from bayesflow import ContinuousApproximator + + approximator = ContinuousApproximator( + adapter=adapter, inference_network=inference_network, summary_network=summary_network + ) + + # Generate test data + num_batches = 2 # Use fewer batches for faster tests + data = simulator.sample((num_batches * batch_size,)) + + # Build approximator + batch = adapter(data) + batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch) + batch_shapes = keras.tree.map_structure(keras.ops.shape, batch) + approximator.build(batch_shapes) + + # Test sampling with the specified method + samples = approximator.sample(num_samples=2, conditions=data, method=method) + + # Verify results + assert isinstance(samples, dict)