Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bayesflow/networks/transformers/mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from bayesflow.networks import MLP
from bayesflow.types import Tensor
from bayesflow.utils import layer_kwargs
from bayesflow.utils import layer_kwargs, filter_kwargs
from bayesflow.utils.decorators import sanitize_input_shape
from bayesflow.utils.serialization import serializable

Expand Down Expand Up @@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -
"""

h = self.input_projector(seq_x) + self.attention(
query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs
query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call)
)
if self.ln_pre is not None:
h = self.ln_pre(h, training=training)
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/networks/transformers/set_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
out : Tensor
Output of shape (batch_size, set_size, output_dim)
"""
summary = self.attention_blocks(input_set, training=training, **kwargs)
summary = self.attention_blocks(input_set, training=training)
summary = self.pooling_by_attention(summary, training=training, **kwargs)
summary = self.output_projector(summary)
return summary
90 changes: 90 additions & 0 deletions tests/test_approximators/test_sample.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import keras
from tests.utils import check_combination_simulator_adapter

Expand All @@ -16,3 +17,92 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter):
samples = approximator.sample(num_samples=2, conditions=data)

assert isinstance(samples, dict)


@pytest.mark.parametrize("inference_network_type", ["flow_matching", "diffusion_model"])
@pytest.mark.parametrize("summary_network_type", ["none", "deep_set", "set_transformer", "time_series"])
@pytest.mark.parametrize("method", ["euler", "rk45", "euler_maruyama"])
def test_approximator_sample_with_integration_methods(
inference_network_type, summary_network_type, method, simulator, adapter
):
"""Test approximator sampling with different integration methods and summary networks.

Tests flow matching and diffusion models with different ODE/SDE solvers:
- euler, rk45: Available for both flow matching and diffusion models
- euler_maruyama: Only for diffusion models (stochastic)

Also tests with different summary network types.
"""
batch_size = 8 # Use smaller batch size for faster tests
check_combination_simulator_adapter(simulator, adapter)

# Skip euler_maruyama for flow matching (deterministic model)
if inference_network_type == "flow_matching" and method == "euler_maruyama":
pytest.skip("euler_maruyama is only available for diffusion models")

# Create inference network based on type
if inference_network_type == "flow_matching":
from bayesflow.networks import FlowMatching, MLP

inference_network = FlowMatching(
subnet=MLP(widths=[32, 32]),
integrate_kwargs={"steps": 10}, # Use fewer steps for faster tests
)
elif inference_network_type == "diffusion_model":
from bayesflow.networks import DiffusionModel, MLP

inference_network = DiffusionModel(
subnet=MLP(widths=[32, 32]),
integrate_kwargs={"steps": 10}, # Use fewer steps for faster tests
)
else:
pytest.skip(f"Unsupported inference network type: {inference_network_type}")

# Create summary network based on type
summary_network = None
if summary_network_type != "none":
if summary_network_type == "deep_set":
from bayesflow.networks import DeepSet, MLP

summary_network = DeepSet(subnet=MLP(widths=[16, 16]))
elif summary_network_type == "set_transformer":
from bayesflow.networks import SetTransformer

summary_network = SetTransformer(embed_dims=[16, 16], mlp_widths=[16, 16])
elif summary_network_type == "time_series":
from bayesflow.networks import TimeSeriesNetwork

summary_network = TimeSeriesNetwork(subnet_kwargs={"widths": [16, 16]}, cell_type="lstm")
else:
pytest.skip(f"Unsupported summary network type: {summary_network_type}")

# Update adapter to include summary variables if summary network is present
from bayesflow import ContinuousApproximator

adapter = ContinuousApproximator.build_adapter(
inference_variables=["mean", "std"],
summary_variables=["x"], # Use x as summary variable for testing
)

# Create approximator
from bayesflow import ContinuousApproximator

approximator = ContinuousApproximator(
adapter=adapter, inference_network=inference_network, summary_network=summary_network
)

# Generate test data
num_batches = 2 # Use fewer batches for faster tests
data = simulator.sample((num_batches * batch_size,))

# Build approximator
batch = adapter(data)
batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
approximator.build(batch_shapes)

# Test sampling with the specified method
samples = approximator.sample(num_samples=2, conditions=data, method=method)

# Verify results
assert isinstance(samples, dict)
Loading