4 changes: 2 additions & 2 deletions bayesflow/networks/transformers/mab.py
@@ -3,7 +3,7 @@

from bayesflow.networks import MLP
from bayesflow.types import Tensor
from bayesflow.utils import layer_kwargs
from bayesflow.utils import layer_kwargs, filter_kwargs
from bayesflow.utils.decorators import sanitize_input_shape
from bayesflow.utils.serialization import serializable

@@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -
"""

h = self.input_projector(seq_x) + self.attention(
query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs
query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call)
)
if self.ln_pre is not None:
h = self.ln_pre(h, training=training)
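The pattern above — `**filter_kwargs(kwargs, self.attention.call)` — forwards only those keyword arguments that the attention layer's `call` signature actually accepts, so sampler-specific arguments (for example an integration `method`) passed in further up the call chain are not handed to Keras layers that would reject them. As a rough sketch of what such a helper does, assuming it filters by inspecting the callable's signature (the real `bayesflow.utils.filter_kwargs` may differ in detail):

```python
import inspect


def filter_kwargs(kwargs: dict, f) -> dict:
    """Keep only the entries of ``kwargs`` that the callable ``f`` accepts."""
    params = inspect.signature(f).parameters
    # If f itself declares a **kwargs parameter, everything can be forwarded.
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in params}
```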
10 changes: 7 additions & 3 deletions bayesflow/networks/transformers/set_transformer.py
@@ -1,7 +1,7 @@
import keras

from bayesflow.types import Tensor
from bayesflow.utils import check_lengths_same
from bayesflow.utils import check_lengths_same, filter_kwargs
from bayesflow.utils.serialization import serializable

from ..summary_network import SummaryNetwork
@@ -147,7 +147,11 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
out : Tensor
Output of shape (batch_size, set_size, output_dim)
"""
summary = self.attention_blocks(input_set, training=training, **kwargs)
summary = self.pooling_by_attention(summary, training=training, **kwargs)
summary = self.attention_blocks(
input_set, training=training, **filter_kwargs(kwargs, self.attention_blocks.call)
)
summary = self.pooling_by_attention(
summary, training=training, **filter_kwargs(kwargs, self.pooling_by_attention.call)
)
summary = self.output_projector(summary)
return summary
90 changes: 90 additions & 0 deletions tests/test_approximators/test_sample.py
@@ -1,3 +1,4 @@
import pytest
import keras
from tests.utils import check_combination_simulator_adapter

@@ -16,3 +17,92 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter):
samples = approximator.sample(num_samples=2, conditions=data)

assert isinstance(samples, dict)


@pytest.mark.parametrize("inference_network_type", ["flow_matching", "diffusion_model"])
@pytest.mark.parametrize("summary_network_type", ["none", "deep_set", "set_transformer", "time_series"])
@pytest.mark.parametrize("method", ["euler", "rk45", "euler_maruyama"])
def test_approximator_sample_with_integration_methods(
inference_network_type, summary_network_type, method, simulator, adapter
):
"""Test approximator sampling with different integration methods and summary networks.

Tests flow matching and diffusion models with different ODE/SDE solvers:
- euler, rk45: Available for both flow matching and diffusion models
- euler_maruyama: Only for diffusion models (stochastic)

Also tests with different summary network types.
"""
batch_size = 8 # Use smaller batch size for faster tests
check_combination_simulator_adapter(simulator, adapter)

# Skip euler_maruyama for flow matching (deterministic model)
if inference_network_type == "flow_matching" and method == "euler_maruyama":
pytest.skip("euler_maruyama is only available for diffusion models")

# Create inference network based on type
if inference_network_type == "flow_matching":
from bayesflow.networks import FlowMatching, MLP

inference_network = FlowMatching(
subnet=MLP(widths=[32, 32]),
integrate_kwargs={"steps": 10}, # Use fewer steps for faster tests
)
elif inference_network_type == "diffusion_model":
from bayesflow.networks import DiffusionModel, MLP

inference_network = DiffusionModel(
subnet=MLP(widths=[32, 32]),
integrate_kwargs={"steps": 10}, # Use fewer steps for faster tests
)
else:
pytest.skip(f"Unsupported inference network type: {inference_network_type}")

# Create summary network based on type
summary_network = None
if summary_network_type != "none":
if summary_network_type == "deep_set":
from bayesflow.networks import DeepSet, MLP

summary_network = DeepSet(subnet=MLP(widths=[16, 16]))
elif summary_network_type == "set_transformer":
from bayesflow.networks import SetTransformer

summary_network = SetTransformer(embed_dims=[16, 16], mlp_widths=[16, 16])
elif summary_network_type == "time_series":
from bayesflow.networks import TimeSeriesNetwork

summary_network = TimeSeriesNetwork(subnet_kwargs={"widths": [16, 16]}, cell_type="lstm")
else:
pytest.skip(f"Unsupported summary network type: {summary_network_type}")

    # Rebuild the adapter so that x is exposed as a summary variable
from bayesflow import ContinuousApproximator

adapter = ContinuousApproximator.build_adapter(
inference_variables=["mean", "std"],
summary_variables=["x"], # Use x as summary variable for testing
)

    # Create approximator (ContinuousApproximator is already imported above)

approximator = ContinuousApproximator(
adapter=adapter, inference_network=inference_network, summary_network=summary_network
)

# Generate test data
num_batches = 2 # Use fewer batches for faster tests
data = simulator.sample((num_batches * batch_size,))

# Build approximator
batch = adapter(data)
batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
approximator.build(batch_shapes)

# Test sampling with the specified method
samples = approximator.sample(num_samples=2, conditions=data, method=method)

# Verify results
assert isinstance(samples, dict)