Skip to content

Commit 5c2f390

Browse files
committed
Add (de)serialization code for TimeSeriesTransformer
1 parent 32b5a03 commit 5c2f390

File tree

1 file changed

+43
-3
lines changed

1 file changed

+43
-3
lines changed

bayesflow/networks/transformers/time_series_transformer.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import keras
22

33
from bayesflow.types import Tensor
4-
from bayesflow.utils import check_lengths_same
5-
from bayesflow.utils.serialization import serializable
4+
from bayesflow.utils import check_lengths_same, model_kwargs
5+
from bayesflow.utils.serialization import deserialize, serializable, serialize
66

77
from ..embeddings import Time2Vec, RecurrentEmbedding
88
from ..summary_network import SummaryNetwork
@@ -103,9 +103,22 @@ def __init__(
103103
# Pooling will be applied as a final step to the abstract representations obtained from set attention
104104
self.pooling = keras.layers.GlobalAvgPool1D()
105105
self.output_projector = keras.layers.Dense(summary_dim)
106-
self.summary_dim = summary_dim
107106

107+
# store variables for serialization
108+
self.summary_dim = summary_dim
109+
self.embed_dims = embed_dims
110+
self.num_heads = num_heads
111+
self.mlp_depths = mlp_depths
112+
self.mlp_widths = mlp_widths
113+
self.dropout = dropout
114+
self.mlp_activation = mlp_activation
115+
self.kernel_initializer = kernel_initializer
116+
self.use_bias = use_bias
117+
self.layer_norm = layer_norm
118+
self._time_embedding_arg = time_embedding
119+
self.time_embed_dim = time_embed_dim
108120
self.time_axis = time_axis
121+
self._kwargs = kwargs
109122

110123
def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
111124
"""Compresses the input sequence into a summary vector of size `summary_dim`.
@@ -147,3 +160,30 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens
147160
summary = self.pooling(inp)
148161
summary = self.output_projector(summary)
149162
return summary
163+
164+
@classmethod
165+
def from_config(cls, config, custom_objects=None):
166+
return cls(**deserialize(config, custom_objects=custom_objects))
167+
168+
def get_config(self):
169+
base_config = super().get_config()
170+
base_config = model_kwargs(base_config)
171+
172+
config = {
173+
"summary_dim": self.summary_dim,
174+
"embed_dims": self.embed_dims,
175+
"num_heads": self.num_heads,
176+
"mlp_depths": self.mlp_depths,
177+
"mlp_widths": self.mlp_widths,
178+
"dropout": self.dropout,
179+
"mlp_activation": self.mlp_activation,
180+
"kernel_initializer": self.kernel_initializer,
181+
"use_bias": self.use_bias,
182+
"layer_norm": self.layer_norm,
183+
"time_embedding": self._time_embedding_arg,
184+
"time_embed_dim": self.time_embed_dim,
185+
"time_axis": self.time_axis,
186+
**self._kwargs,
187+
}
188+
189+
return base_config | serialize(config)

0 commit comments

Comments
 (0)