diff --git a/bayesflow/networks/transformers/fusion_transformer.py b/bayesflow/networks/transformers/fusion_transformer.py index 1821c25d2..cc5f14e65 100644 --- a/bayesflow/networks/transformers/fusion_transformer.py +++ b/bayesflow/networks/transformers/fusion_transformer.py @@ -2,8 +2,8 @@ from keras import layers from bayesflow.types import Tensor -from bayesflow.utils import check_lengths_same -from bayesflow.utils.serialization import serializable +from bayesflow.utils import check_lengths_same, model_kwargs +from bayesflow.utils.serialization import deserialize, serializable, serialize from ..summary_network import SummaryNetwork @@ -121,6 +121,19 @@ def __init__( self.output_projector = keras.layers.Dense(summary_dim) self.summary_dim = summary_dim + self.embed_dims = embed_dims + self.num_heads = num_heads + self.mlp_depths = mlp_depths + self.mlp_widths = mlp_widths + self.dropout = dropout + self.mlp_activation = mlp_activation + self.kernel_initializer = kernel_initializer + self.use_bias = use_bias + self.layer_norm = layer_norm + self.template_type = template_type + self.bidirectional = bidirectional + self.template_dim = template_dim + self._kwargs = kwargs def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor: """Compresses the input sequence into a summary vector of size `summary_dim`. @@ -151,3 +164,30 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens summary = self.attention_blocks[-1](keras.ops.expand_dims(template, axis=1), rep, training=training, **kwargs) summary = self.output_projector(keras.ops.squeeze(summary, axis=1)) return summary + + @classmethod + def from_config(cls, config, custom_objects=None): + return cls(**deserialize(config, custom_objects=custom_objects)) + + def get_config(self): + base_config = super().get_config() + base_config = model_kwargs(base_config) + + config = { + "summary_dim": self.summary_dim, + "embed_dims": self.embed_dims, + "num_heads": self.num_heads, + "mlp_depths": self.mlp_depths, + "mlp_widths": self.mlp_widths, + "dropout": self.dropout, + "mlp_activation": self.mlp_activation, + "kernel_initializer": self.kernel_initializer, + "use_bias": self.use_bias, + "layer_norm": self.layer_norm, + "template_type": self.template_type, + "bidirectional": self.bidirectional, + "template_dim": self.template_dim, + **self._kwargs, + } + + return base_config | serialize(config) diff --git a/bayesflow/networks/transformers/time_series_transformer.py b/bayesflow/networks/transformers/time_series_transformer.py index dac0e52b7..4e5599ae6 100644 --- a/bayesflow/networks/transformers/time_series_transformer.py +++ b/bayesflow/networks/transformers/time_series_transformer.py @@ -1,8 +1,8 @@ import keras from bayesflow.types import Tensor -from bayesflow.utils import check_lengths_same -from bayesflow.utils.serialization import serializable +from bayesflow.utils import check_lengths_same, model_kwargs +from bayesflow.utils.serialization import deserialize, serializable, serialize from ..embeddings import Time2Vec, RecurrentEmbedding from ..summary_network import SummaryNetwork @@ -103,9 +103,22 @@ def __init__( # Pooling will be applied as a final step to the abstract representations obtained from set attention self.pooling = keras.layers.GlobalAvgPool1D() self.output_projector = keras.layers.Dense(summary_dim) - self.summary_dim = summary_dim + # store variables for serialization + self.summary_dim = summary_dim + self.embed_dims = embed_dims + self.num_heads = num_heads + self.mlp_depths = mlp_depths + self.mlp_widths = mlp_widths + self.dropout = dropout + self.mlp_activation = mlp_activation + self.kernel_initializer = kernel_initializer + self.use_bias = use_bias + self.layer_norm = layer_norm + self._time_embedding_arg = time_embedding + self.time_embed_dim = time_embed_dim self.time_axis = time_axis + self._kwargs = kwargs def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor: """Compresses the input sequence into a summary vector of size `summary_dim`. @@ -147,3 +160,30 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens summary = self.pooling(inp) summary = self.output_projector(summary) return summary + + @classmethod + def from_config(cls, config, custom_objects=None): + return cls(**deserialize(config, custom_objects=custom_objects)) + + def get_config(self): + base_config = super().get_config() + base_config = model_kwargs(base_config) + + config = { + "summary_dim": self.summary_dim, + "embed_dims": self.embed_dims, + "num_heads": self.num_heads, + "mlp_depths": self.mlp_depths, + "mlp_widths": self.mlp_widths, + "dropout": self.dropout, + "mlp_activation": self.mlp_activation, + "kernel_initializer": self.kernel_initializer, + "use_bias": self.use_bias, + "layer_norm": self.layer_norm, + "time_embedding": self._time_embedding_arg, + "time_embed_dim": self.time_embed_dim, + "time_axis": self.time_axis, + **self._kwargs, + } + + return base_config | serialize(config) diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index b4ad8df99..84c011812 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -119,6 +119,20 @@ def time_series_network(summary_dim): return TimeSeriesNetwork(summary_dim=summary_dim) +@pytest.fixture(scope="function") +def time_series_transformer(summary_dim): + from bayesflow.networks import TimeSeriesTransformer + + return TimeSeriesTransformer(summary_dim=summary_dim) + + +@pytest.fixture(scope="function") +def fusion_transformer(summary_dim): + from bayesflow.networks import FusionTransformer + + return FusionTransformer(summary_dim=summary_dim) + + @pytest.fixture(scope="function") def set_transformer(summary_dim): from bayesflow.networks import SetTransformer @@ -133,7 +147,17 @@ def deep_set(summary_dim): return DeepSet(summary_dim=summary_dim) -@pytest.fixture(params=[None, "time_series_network", "set_transformer", "deep_set"], scope="function") +@pytest.fixture( + params=[ + None, + "time_series_network", + "time_series_transformer", + "fusion_transformer", + "set_transformer", + "deep_set", + ], + scope="function", +) def summary_network(request, summary_dim): if request.param is None: return None