Commit cede7f8

Fix SetTransformer kwargs handling (#574)
* fix kwargs in sample and set transformer
1 parent 9abb126 commit cede7f8
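
The fix forwards to the attention layer only those keyword arguments its call signature actually accepts, instead of passing through everything the block receives. As a rough, hypothetical sketch of what a filter_kwargs-style helper does (the real bayesflow.utils.filter_kwargs may be implemented differently):

# Hypothetical sketch only; not the implementation from bayesflow.utils.
import inspect


def filter_kwargs_sketch(kwargs: dict, f) -> dict:
    """Keep only the keyword arguments that the callable f explicitly accepts."""
    params = inspect.signature(f).parameters
    # If f already accepts **kwargs, nothing needs to be dropped.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {name: value for name, value in kwargs.items() if name in params}

With such a filter in place, sampling-specific keywords (for example a method argument forwarded from approximator.sample) no longer reach layers whose call signatures do not accept them.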

File tree

3 files changed: +93 -3 lines changed

bayesflow/networks/transformers/mab.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 
 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
-from bayesflow.utils import layer_kwargs
+from bayesflow.utils import layer_kwargs, filter_kwargs
 from bayesflow.utils.decorators import sanitize_input_shape
 from bayesflow.utils.serialization import serializable
 
@@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -> Tensor:
         """
 
         h = self.input_projector(seq_x) + self.attention(
-            query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs
+            query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call)
         )
         if self.ln_pre is not None:
             h = self.ln_pre(h, training=training)

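For context, the attention attribute in MAB is presumably a keras.layers.MultiHeadAttention, whose call does not take arbitrary **kwargs; before this fix, any extra keyword forwarded down from a sampling call would therefore trigger an error. A minimal, hypothetical illustration (not code from this repository):

# Assumes standard keras.layers.MultiHeadAttention behaviour: unknown
# keyword arguments are rejected rather than silently ignored.
import numpy as np
import keras

attention = keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
x = np.random.normal(size=(2, 3, 4)).astype("float32")

out = attention(query=x, key=x, value=x)  # accepted keywords only: works
# attention(query=x, key=x, value=x, method="euler")  # fails: unrecognized keyword
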
bayesflow/networks/transformers/set_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         out : Tensor
             Output of shape (batch_size, set_size, output_dim)
         """
-        summary = self.attention_blocks(input_set, training=training, **kwargs)
+        summary = self.attention_blocks(input_set, training=training)
         summary = self.pooling_by_attention(summary, training=training, **kwargs)
         summary = self.output_projector(summary)
         return summary

tests/test_approximators/test_sample.py

Lines changed: 90 additions & 0 deletions
@@ -1,3 +1,4 @@
+import pytest
 import keras
 from tests.utils import check_combination_simulator_adapter
 
@@ -16,3 +17,92 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter):
     samples = approximator.sample(num_samples=2, conditions=data)
 
     assert isinstance(samples, dict)
+
+
+@pytest.mark.parametrize("inference_network_type", ["flow_matching", "diffusion_model"])
+@pytest.mark.parametrize("summary_network_type", ["none", "deep_set", "set_transformer", "time_series"])
+@pytest.mark.parametrize("method", ["euler", "rk45", "euler_maruyama"])
+def test_approximator_sample_with_integration_methods(
+    inference_network_type, summary_network_type, method, simulator, adapter
+):
+    """Test approximator sampling with different integration methods and summary networks.
+
+    Tests flow matching and diffusion models with different ODE/SDE solvers:
+    - euler, rk45: Available for both flow matching and diffusion models
+    - euler_maruyama: Only for diffusion models (stochastic)
+
+    Also tests with different summary network types.
+    """
+    batch_size = 8  # Use smaller batch size for faster tests
+    check_combination_simulator_adapter(simulator, adapter)
+
+    # Skip euler_maruyama for flow matching (deterministic model)
+    if inference_network_type == "flow_matching" and method == "euler_maruyama":
+        pytest.skip("euler_maruyama is only available for diffusion models")
+
+    # Create inference network based on type
+    if inference_network_type == "flow_matching":
+        from bayesflow.networks import FlowMatching, MLP
+
+        inference_network = FlowMatching(
+            subnet=MLP(widths=[32, 32]),
+            integrate_kwargs={"steps": 10},  # Use fewer steps for faster tests
+        )
+    elif inference_network_type == "diffusion_model":
+        from bayesflow.networks import DiffusionModel, MLP
+
+        inference_network = DiffusionModel(
+            subnet=MLP(widths=[32, 32]),
+            integrate_kwargs={"steps": 10},  # Use fewer steps for faster tests
+        )
+    else:
+        pytest.skip(f"Unsupported inference network type: {inference_network_type}")
+
+    # Create summary network based on type
+    summary_network = None
+    if summary_network_type != "none":
+        if summary_network_type == "deep_set":
+            from bayesflow.networks import DeepSet, MLP
+
+            summary_network = DeepSet(subnet=MLP(widths=[16, 16]))
+        elif summary_network_type == "set_transformer":
+            from bayesflow.networks import SetTransformer
+
+            summary_network = SetTransformer(embed_dims=[16, 16], mlp_widths=[16, 16])
+        elif summary_network_type == "time_series":
+            from bayesflow.networks import TimeSeriesNetwork
+
+            summary_network = TimeSeriesNetwork(subnet_kwargs={"widths": [16, 16]}, cell_type="lstm")
+        else:
+            pytest.skip(f"Unsupported summary network type: {summary_network_type}")
+
+    # Update adapter to include summary variables if summary network is present
+    from bayesflow import ContinuousApproximator
+
+    adapter = ContinuousApproximator.build_adapter(
+        inference_variables=["mean", "std"],
+        summary_variables=["x"],  # Use x as summary variable for testing
+    )
+
+    # Create approximator
+    from bayesflow import ContinuousApproximator
+
+    approximator = ContinuousApproximator(
+        adapter=adapter, inference_network=inference_network, summary_network=summary_network
+    )
+
+    # Generate test data
+    num_batches = 2  # Use fewer batches for faster tests
+    data = simulator.sample((num_batches * batch_size,))
+
+    # Build approximator
+    batch = adapter(data)
+    batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
+    batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
+    approximator.build(batch_shapes)
+
+    # Test sampling with the specified method
+    samples = approximator.sample(num_samples=2, conditions=data, method=method)
+
+    # Verify results
+    assert isinstance(samples, dict)

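To exercise only the newly added parametrized cases locally, one possible invocation (test path and -k expression taken from the diff above) is:

import pytest

# Run just the new integration-method test from this file.
pytest.main(["tests/test_approximators/test_sample.py", "-k", "integration_methods"])
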
0 commit comments
