Skip to content

Commit a814aa5

Browse files
committed
Add serialization and tests for FusionTransformer
1 parent 5c2f390 commit a814aa5

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

bayesflow/networks/transformers/fusion_transformer.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from keras import layers
33

44
from bayesflow.types import Tensor
5-
from bayesflow.utils import check_lengths_same
6-
from bayesflow.utils.serialization import serializable
5+
from bayesflow.utils import check_lengths_same, model_kwargs
6+
from bayesflow.utils.serialization import deserialize, serializable, serialize
77

88
from ..summary_network import SummaryNetwork
99

@@ -121,6 +121,19 @@ def __init__(
121121

122122
self.output_projector = keras.layers.Dense(summary_dim)
123123
self.summary_dim = summary_dim
124+
self.embed_dims = embed_dims
125+
self.num_heads = num_heads
126+
self.mlp_depths = mlp_depths
127+
self.mlp_widths = mlp_widths
128+
self.dropout = dropout
129+
self.mlp_activation = mlp_activation
130+
self.kernel_initializer = kernel_initializer
131+
self.use_bias = use_bias
132+
self.layer_norm = layer_norm
133+
self.template_type = template_type
134+
self.bidirectional = bidirectional
135+
self.template_dim = template_dim
136+
self._kwargs = kwargs
124137

125138
def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
126139
"""Compresses the input sequence into a summary vector of size `summary_dim`.
@@ -151,3 +164,30 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens
151164
summary = self.attention_blocks[-1](keras.ops.expand_dims(template, axis=1), rep, training=training, **kwargs)
152165
summary = self.output_projector(keras.ops.squeeze(summary, axis=1))
153166
return summary
167+
168+
@classmethod
169+
def from_config(cls, config, custom_objects=None):
170+
return cls(**deserialize(config, custom_objects=custom_objects))
171+
172+
def get_config(self):
173+
base_config = super().get_config()
174+
base_config = model_kwargs(base_config)
175+
176+
config = {
177+
"summary_dim": self.summary_dim,
178+
"embed_dims": self.embed_dims,
179+
"num_heads": self.num_heads,
180+
"mlp_depths": self.mlp_depths,
181+
"mlp_widths": self.mlp_widths,
182+
"dropout": self.dropout,
183+
"mlp_activation": self.mlp_activation,
184+
"kernel_initializer": self.kernel_initializer,
185+
"use_bias": self.use_bias,
186+
"layer_norm": self.layer_norm,
187+
"template_type": self.template_type,
188+
"bidirectional": self.bidirectional,
189+
"template_dim": self.template_dim,
190+
**self._kwargs,
191+
}
192+
193+
return base_config | serialize(config)

tests/test_networks/conftest.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ def time_series_transformer(summary_dim):
126126
return TimeSeriesTransformer(summary_dim=summary_dim)
127127

128128

129+
@pytest.fixture(scope="function")
130+
def fusion_transformer(summary_dim):
131+
from bayesflow.networks import FusionTransformer
132+
133+
return FusionTransformer(summary_dim=summary_dim)
134+
135+
129136
@pytest.fixture(scope="function")
130137
def set_transformer(summary_dim):
131138
from bayesflow.networks import SetTransformer
@@ -141,7 +148,15 @@ def deep_set(summary_dim):
141148

142149

143150
@pytest.fixture(
144-
params=[None, "time_series_network", "time_series_transformer", "set_transformer", "deep_set"], scope="function"
151+
params=[
152+
None,
153+
"time_series_network",
154+
"time_series_transformer",
155+
"fusion_transformer",
156+
"set_transformer",
157+
"deep_set",
158+
],
159+
scope="function",
145160
)
146161
def summary_network(request, summary_dim):
147162
if request.param is None:

0 commit comments

Comments
 (0)