44 changes: 42 additions & 2 deletions bayesflow/networks/transformers/fusion_transformer.py
@@ -2,8 +2,8 @@
 from keras import layers
 
 from bayesflow.types import Tensor
-from bayesflow.utils import check_lengths_same
-from bayesflow.utils.serialization import serializable
+from bayesflow.utils import check_lengths_same, model_kwargs
+from bayesflow.utils.serialization import deserialize, serializable, serialize
 
 from ..summary_network import SummaryNetwork
 
@@ -121,6 +121,19 @@ def __init__(
 
         self.output_projector = keras.layers.Dense(summary_dim)
         self.summary_dim = summary_dim
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        self.mlp_depths = mlp_depths
+        self.mlp_widths = mlp_widths
+        self.dropout = dropout
+        self.mlp_activation = mlp_activation
+        self.kernel_initializer = kernel_initializer
+        self.use_bias = use_bias
+        self.layer_norm = layer_norm
+        self.template_type = template_type
+        self.bidirectional = bidirectional
+        self.template_dim = template_dim
+        self._kwargs = kwargs
 
     def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
         """Compresses the input sequence into a summary vector of size `summary_dim`.
@@ -151,3 +164,30 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
         summary = self.attention_blocks[-1](keras.ops.expand_dims(template, axis=1), rep, training=training, **kwargs)
         summary = self.output_projector(keras.ops.squeeze(summary, axis=1))
         return summary
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        return cls(**deserialize(config, custom_objects=custom_objects))
+
+    def get_config(self):
+        base_config = super().get_config()
+        base_config = model_kwargs(base_config)
+
+        config = {
+            "summary_dim": self.summary_dim,
+            "embed_dims": self.embed_dims,
+            "num_heads": self.num_heads,
+            "mlp_depths": self.mlp_depths,
+            "mlp_widths": self.mlp_widths,
+            "dropout": self.dropout,
+            "mlp_activation": self.mlp_activation,
+            "kernel_initializer": self.kernel_initializer,
+            "use_bias": self.use_bias,
+            "layer_norm": self.layer_norm,
+            "template_type": self.template_type,
+            "bidirectional": self.bidirectional,
+            "template_dim": self.template_dim,
+            **self._kwargs,
+        }
+
+        return base_config | serialize(config)
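Reviewers who want to sanity-check the new methods locally can run a minimal round-trip. This sketch is not part of the PR; it assumes the default constructor arguments and the `bayesflow.networks.FusionTransformer` import path used in the tests below:

```python
# Hypothetical smoke test for the serialization cycle added above.
from bayesflow.networks import FusionTransformer

net = FusionTransformer(summary_dim=8)

# get_config() now captures every constructor argument, so from_config()
# can rebuild an equivalent (untrained) network from the config alone.
config = net.get_config()
clone = FusionTransformer.from_config(config)

assert clone.summary_dim == net.summary_dim
```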
46 changes: 43 additions & 3 deletions bayesflow/networks/transformers/time_series_transformer.py
@@ -1,8 +1,8 @@
 import keras
 
 from bayesflow.types import Tensor
-from bayesflow.utils import check_lengths_same
-from bayesflow.utils.serialization import serializable
+from bayesflow.utils import check_lengths_same, model_kwargs
+from bayesflow.utils.serialization import deserialize, serializable, serialize
 
 from ..embeddings import Time2Vec, RecurrentEmbedding
 from ..summary_network import SummaryNetwork
@@ -103,9 +103,22 @@ def __init__(
         # Pooling will be applied as a final step to the abstract representations obtained from set attention
         self.pooling = keras.layers.GlobalAvgPool1D()
         self.output_projector = keras.layers.Dense(summary_dim)
-        self.summary_dim = summary_dim
 
+        # store variables for serialization
+        self.summary_dim = summary_dim
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        self.mlp_depths = mlp_depths
+        self.mlp_widths = mlp_widths
+        self.dropout = dropout
+        self.mlp_activation = mlp_activation
+        self.kernel_initializer = kernel_initializer
+        self.use_bias = use_bias
+        self.layer_norm = layer_norm
+        self._time_embedding_arg = time_embedding
+        self.time_embed_dim = time_embed_dim
+        self.time_axis = time_axis
+        self._kwargs = kwargs
 
     def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
         """Compresses the input sequence into a summary vector of size `summary_dim`.
@@ -147,3 +160,30 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
         summary = self.pooling(inp)
         summary = self.output_projector(summary)
         return summary
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        return cls(**deserialize(config, custom_objects=custom_objects))
+
+    def get_config(self):
+        base_config = super().get_config()
+        base_config = model_kwargs(base_config)
+
+        config = {
+            "summary_dim": self.summary_dim,
+            "embed_dims": self.embed_dims,
+            "num_heads": self.num_heads,
+            "mlp_depths": self.mlp_depths,
+            "mlp_widths": self.mlp_widths,
+            "dropout": self.dropout,
+            "mlp_activation": self.mlp_activation,
+            "kernel_initializer": self.kernel_initializer,
+            "use_bias": self.use_bias,
+            "layer_norm": self.layer_norm,
+            "time_embedding": self._time_embedding_arg,
+            "time_embed_dim": self.time_embed_dim,
+            "time_axis": self.time_axis,
+            **self._kwargs,
+        }
+
+        return base_config | serialize(config)
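The same pattern can be exercised end to end for the time-series variant. A hedged sketch (batch shape and argument values are illustrative, not taken from this PR):

```python
# Hypothetical usage: summarize a toy batch, then round-trip the config.
import keras
from bayesflow.networks import TimeSeriesTransformer

net = TimeSeriesTransformer(summary_dim=4)
x = keras.random.normal((2, 16, 3))  # (batch, time steps, observed variables)

summary = net(x)  # builds the network; output should have shape (2, 4)

rebuilt = TimeSeriesTransformer.from_config(net.get_config())
assert rebuilt.summary_dim == net.summary_dim
```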
26 changes: 25 additions & 1 deletion tests/test_networks/conftest.py
@@ -119,6 +119,20 @@ def time_series_network(summary_dim):
     return TimeSeriesNetwork(summary_dim=summary_dim)
 
 
+@pytest.fixture(scope="function")
+def time_series_transformer(summary_dim):
+    from bayesflow.networks import TimeSeriesTransformer
+
+    return TimeSeriesTransformer(summary_dim=summary_dim)
+
+
+@pytest.fixture(scope="function")
+def fusion_transformer(summary_dim):
+    from bayesflow.networks import FusionTransformer
+
+    return FusionTransformer(summary_dim=summary_dim)
+
+
 @pytest.fixture(scope="function")
 def set_transformer(summary_dim):
     from bayesflow.networks import SetTransformer
@@ -133,7 +147,17 @@ def deep_set(summary_dim):
     return DeepSet(summary_dim=summary_dim)
 
 
-@pytest.fixture(params=[None, "time_series_network", "set_transformer", "deep_set"], scope="function")
+@pytest.fixture(
+    params=[
+        None,
+        "time_series_network",
+        "time_series_transformer",
+        "fusion_transformer",
+        "set_transformer",
+        "deep_set",
+    ],
+    scope="function",
+)
 def summary_network(request, summary_dim):
     if request.param is None:
         return None
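Note on the widened fixture: the string params work because `summary_network` resolves each name to the fixture of the same name. The body collapsed after `return None` presumably follows the usual pytest idiom, something like this sketch (not verbatim from the repo):

```python
# Hypothetical continuation of the collapsed fixture body: resolve the
# parametrized fixture name to the fixture instance itself.
return request.getfixturevalue(request.param)
```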