From ca7f3bdaf9700d2fd6268c844c4abf6fb5963139 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Tue, 16 Sep 2025 12:21:13 +0200
Subject: [PATCH 1/4] fix kwargs in sample

---
 bayesflow/networks/transformers/mab.py        |  4 +-
 .../networks/transformers/set_transformer.py  | 10 ++-
 tests/test_approximators/test_sample.py       | 90 +++++++++++++++++++
 3 files changed, 99 insertions(+), 5 deletions(-)

diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py
index 5bd7c9dff..eddb8cf09 100644
--- a/bayesflow/networks/transformers/mab.py
+++ b/bayesflow/networks/transformers/mab.py
@@ -3,7 +3,7 @@
 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
-from bayesflow.utils import layer_kwargs
+from bayesflow.utils import layer_kwargs, filter_kwargs
 from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable
 
 
@@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -> Tensor:
         """
 
         h = self.input_projector(seq_x) + self.attention(
-            query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs
+            query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call)
         )
         if self.ln_pre is not None:
             h = self.ln_pre(h, training=training)
diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py
index d0d748067..94690f3ef 100644
--- a/bayesflow/networks/transformers/set_transformer.py
+++ b/bayesflow/networks/transformers/set_transformer.py
@@ -1,7 +1,7 @@
 import keras
 
 from bayesflow.types import Tensor
-from bayesflow.utils import check_lengths_same
+from bayesflow.utils import check_lengths_same, filter_kwargs
 from bayesflow.utils.serialization import serializable
 
 from ..summary_network import SummaryNetwork
@@ -147,7 +147,11 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         out : Tensor
             Output of shape (batch_size, set_size, output_dim)
         """
-        summary = self.attention_blocks(input_set, training=training, **kwargs)
-        summary = self.pooling_by_attention(summary, training=training, **kwargs)
+        summary = self.attention_blocks(
+            input_set, training=training, **filter_kwargs(kwargs, self.attention_blocks.call)
+        )
+        summary = self.pooling_by_attention(
+            summary, training=training, **filter_kwargs(kwargs, self.pooling_by_attention.call)
+        )
         summary = self.output_projector(summary)
         return summary
diff --git a/tests/test_approximators/test_sample.py b/tests/test_approximators/test_sample.py
index c62ffc581..e76b72a40 100644
--- a/tests/test_approximators/test_sample.py
+++ b/tests/test_approximators/test_sample.py
@@ -1,3 +1,4 @@
+import pytest
 import keras
 
 from tests.utils import check_combination_simulator_adapter
@@ -16,3 +17,92 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter):
     samples = approximator.sample(num_samples=2, conditions=data)
 
     assert isinstance(samples, dict)
+
+
+@pytest.mark.parametrize("inference_network_type", ["flow_matching", "diffusion_model"])
+@pytest.mark.parametrize("summary_network_type", ["none", "deep_set", "set_transformer", "time_series"])
+@pytest.mark.parametrize("method", ["euler", "rk45", "euler_maruyama"])
+def test_approximator_sample_with_integration_methods(
+    inference_network_type, summary_network_type, method, simulator, adapter
+):
+    """Test approximator sampling with different integration methods and summary networks.
+
+    Tests flow matching and diffusion models with different ODE/SDE solvers:
+    - euler, rk45: Available for both flow matching and diffusion models
+    - euler_maruyama: Only for diffusion models (stochastic)
+
+    Also tests with different summary network types.
+    """
+    batch_size = 8  # Use smaller batch size for faster tests
+    check_combination_simulator_adapter(simulator, adapter)
+
+    # Skip euler_maruyama for flow matching (deterministic model)
+    if inference_network_type == "flow_matching" and method == "euler_maruyama":
+        pytest.skip("euler_maruyama is only available for diffusion models")
+
+    # Create inference network based on type
+    if inference_network_type == "flow_matching":
+        from bayesflow.networks import FlowMatching, MLP
+
+        inference_network = FlowMatching(
+            subnet=MLP(widths=[32, 32]),
+            integrate_kwargs={"steps": 10},  # Use fewer steps for faster tests
+        )
+    elif inference_network_type == "diffusion_model":
+        from bayesflow.networks import DiffusionModel, MLP
+
+        inference_network = DiffusionModel(
+            subnet=MLP(widths=[32, 32]),
+            integrate_kwargs={"steps": 10},  # Use fewer steps for faster tests
+        )
+    else:
+        pytest.skip(f"Unsupported inference network type: {inference_network_type}")
+
+    # Create summary network based on type
+    summary_network = None
+    if summary_network_type != "none":
+        if summary_network_type == "deep_set":
+            from bayesflow.networks import DeepSet, MLP
+
+            summary_network = DeepSet(subnet=MLP(widths=[16, 16]))
+        elif summary_network_type == "set_transformer":
+            from bayesflow.networks import SetTransformer
+
+            summary_network = SetTransformer(embed_dims=[16, 16], mlp_widths=[16, 16])
+        elif summary_network_type == "time_series":
+            from bayesflow.networks import TimeSeriesNetwork
+
+            summary_network = TimeSeriesNetwork(subnet_kwargs={"widths": [16, 16]}, cell_type="lstm")
+        else:
+            pytest.skip(f"Unsupported summary network type: {summary_network_type}")
+
+    # Update adapter to include summary variables if summary network is present
+    from bayesflow import ContinuousApproximator
+
+    adapter = ContinuousApproximator.build_adapter(
+        inference_variables=["mean", "std"],
+        summary_variables=["x"],  # Use x as summary variable for testing
+    )
+
+    # Create approximator
+    from bayesflow import ContinuousApproximator
+
+    approximator = ContinuousApproximator(
+        adapter=adapter, inference_network=inference_network, summary_network=summary_network
+    )
+
+    # Generate test data
+    num_batches = 2  # Use fewer batches for faster tests
+    data = simulator.sample((num_batches * batch_size,))
+
+    # Build approximator
+    batch = adapter(data)
+    batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
+    batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
+    approximator.build(batch_shapes)
+
+    # Test sampling with the specified method
+    samples = approximator.sample(num_samples=2, conditions=data, method=method)
+
+    # Verify results
+    assert isinstance(samples, dict)
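Note on PATCH 1: the fix relies on bayesflow.utils.filter_kwargs to drop keyword arguments that the callee's signature does not accept, so sampler options such as `steps` or `method` no longer reach the keras attention layers. A minimal sketch of that filtering behavior, assuming it matches keys against the callable's signature (an illustrative reimplementation, not the library's actual code):

    import inspect

    def filter_kwargs_sketch(kwargs: dict, f) -> dict:
        # Keep only the keyword arguments that f's signature explicitly accepts.
        accepted = inspect.signature(f).parameters
        return {name: value for name, value in kwargs.items() if name in accepted}

    # filter_kwargs_sketch({"steps": 10, "training": True}, layer.call)
    # would return {"training": True} if layer.call has no `steps` parameter.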
From 2c161c6d85f675eac9347d036cbc93b76524b323 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Tue, 16 Sep 2025 15:32:20 +0200
Subject: [PATCH 2/4] fix kwargs in set transformer

---
 bayesflow/networks/transformers/isab.py            |  1 +
 bayesflow/networks/transformers/mab.py             |  4 ++--
 bayesflow/networks/transformers/set_transformer.py | 10 +++-------
 3 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/bayesflow/networks/transformers/isab.py b/bayesflow/networks/transformers/isab.py
index 03f15a561..1b763c2b3 100644
--- a/bayesflow/networks/transformers/isab.py
+++ b/bayesflow/networks/transformers/isab.py
@@ -107,5 +107,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         batch_size = keras.ops.shape(input_set)[0]
         inducing_points_expanded = keras.ops.expand_dims(self.inducing_points, axis=0)
         inducing_points_tiled = keras.ops.tile(inducing_points_expanded, [batch_size, 1, 1])
+        print(kwargs)
         h = self.mab0(inducing_points_tiled, input_set, training=training, **kwargs)
         return self.mab1(input_set, h, training=training, **kwargs)
diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py
index eddb8cf09..5bd7c9dff 100644
--- a/bayesflow/networks/transformers/mab.py
+++ b/bayesflow/networks/transformers/mab.py
@@ -3,7 +3,7 @@
 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
-from bayesflow.utils import layer_kwargs, filter_kwargs
+from bayesflow.utils import layer_kwargs
 from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable
 
 
@@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -> Tensor:
         """
 
         h = self.input_projector(seq_x) + self.attention(
-            query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call)
+            query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs
         )
         if self.ln_pre is not None:
             h = self.ln_pre(h, training=training)
diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py
index 94690f3ef..7e9da76ea 100644
--- a/bayesflow/networks/transformers/set_transformer.py
+++ b/bayesflow/networks/transformers/set_transformer.py
@@ -1,7 +1,7 @@
 import keras
 
 from bayesflow.types import Tensor
-from bayesflow.utils import check_lengths_same, filter_kwargs
+from bayesflow.utils import check_lengths_same
 from bayesflow.utils.serialization import serializable
 
 from ..summary_network import SummaryNetwork
@@ -147,11 +147,7 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         out : Tensor
             Output of shape (batch_size, set_size, output_dim)
         """
-        summary = self.attention_blocks(
-            input_set, training=training, **filter_kwargs(kwargs, self.attention_blocks.call)
-        )
-        summary = self.pooling_by_attention(
-            summary, training=training, **filter_kwargs(kwargs, self.pooling_by_attention.call)
-        )
+        summary = self.attention_blocks(input_set, training=training)
+        summary = self.pooling_by_attention(summary, training=training)
         summary = self.output_projector(summary)
         return summary
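Note on PATCH 2: instead of filtering, this revision stops forwarding **kwargs from SetTransformer.call altogether (and adds a temporary debug print in ISAB, removed again in PATCH 4). The failure mode both approaches target is that keras layers reject keyword arguments their call signatures do not declare; a hedged repro sketch, assuming a standard Keras 3 setup:

    import keras

    mha = keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
    x = keras.random.normal((1, 4, 8))

    out = mha(query=x, value=x)        # fine
    # mha(query=x, value=x, steps=10)  # TypeError: unexpected keyword argument 'steps'

Dropping kwargs wholesale would also discard legitimate arguments, which is why PATCH 3 restores forwarding for the pooling call.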
From 9d4c1a1c605c7e226ea72f97321e1a55c7e718ac Mon Sep 17 00:00:00 2001
From: arrjon
Date: Tue, 16 Sep 2025 15:37:38 +0200
Subject: [PATCH 3/4] fix kwargs in set transformer

---
 bayesflow/networks/transformers/mab.py             | 4 ++--
 bayesflow/networks/transformers/set_transformer.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py
index 5bd7c9dff..eddb8cf09 100644
--- a/bayesflow/networks/transformers/mab.py
+++ b/bayesflow/networks/transformers/mab.py
@@ -3,7 +3,7 @@
 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
-from bayesflow.utils import layer_kwargs
+from bayesflow.utils import layer_kwargs, filter_kwargs
 from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable
 
 
@@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -> Tensor:
         """
 
         h = self.input_projector(seq_x) + self.attention(
-            query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs
+            query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call)
         )
         if self.ln_pre is not None:
             h = self.ln_pre(h, training=training)
diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py
index 7e9da76ea..bd8290272 100644
--- a/bayesflow/networks/transformers/set_transformer.py
+++ b/bayesflow/networks/transformers/set_transformer.py
@@ -148,6 +148,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
             Output of shape (batch_size, set_size, output_dim)
         """
         summary = self.attention_blocks(input_set, training=training)
-        summary = self.pooling_by_attention(summary, training=training)
+        summary = self.pooling_by_attention(summary, training=training, **kwargs)
         summary = self.output_projector(summary)
         return summary

From ea0659d14962e4b423a42e5bbf53dd79f1797eb9 Mon Sep 17 00:00:00 2001
From: arrjon
Date: Tue, 16 Sep 2025 15:38:40 +0200
Subject: [PATCH 4/4] remove print

---
 bayesflow/networks/transformers/isab.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/bayesflow/networks/transformers/isab.py b/bayesflow/networks/transformers/isab.py
index 1b763c2b3..03f15a561 100644
--- a/bayesflow/networks/transformers/isab.py
+++ b/bayesflow/networks/transformers/isab.py
@@ -107,6 +107,5 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         batch_size = keras.ops.shape(input_set)[0]
         inducing_points_expanded = keras.ops.expand_dims(self.inducing_points, axis=0)
         inducing_points_tiled = keras.ops.tile(inducing_points_expanded, [batch_size, 1, 1])
-        print(kwargs)
         h = self.mab0(inducing_points_tiled, input_set, training=training, **kwargs)
         return self.mab1(input_set, h, training=training, **kwargs)
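With the full series applied, solver options pass through sampling without tripping the summary network. A condensed usage sketch mirroring the new test (the training step and the `data` conditions are assumed from the surrounding test suite; this is not a verbatim excerpt):

    from bayesflow import ContinuousApproximator
    from bayesflow.networks import DiffusionModel, MLP, SetTransformer

    approximator = ContinuousApproximator(
        adapter=ContinuousApproximator.build_adapter(
            inference_variables=["mean", "std"],
            summary_variables=["x"],
        ),
        inference_network=DiffusionModel(subnet=MLP(widths=[32, 32]), integrate_kwargs={"steps": 10}),
        summary_network=SetTransformer(embed_dims=[16, 16], mlp_widths=[16, 16]),
    )

    # After building on a batch (as in the test above), the SDE solver can be
    # selected at sampling time; `method` now reaches only the inference network:
    # samples = approximator.sample(num_samples=2, conditions=data, method="euler_maruyama")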