|
1 | 1 | import keras |
2 | 2 |
|
3 | 3 | from bayesflow.types import Tensor |
4 | | -from bayesflow.utils import check_lengths_same |
5 | | -from bayesflow.utils.serialization import serializable |
| 4 | +from bayesflow.utils import check_lengths_same, model_kwargs |
| 5 | +from bayesflow.utils.serialization import deserialize, serializable, serialize |
6 | 6 |
|
7 | 7 | from ..embeddings import Time2Vec, RecurrentEmbedding |
8 | 8 | from ..summary_network import SummaryNetwork |
@@ -103,9 +103,22 @@ def __init__( |
103 | 103 | # Pooling will be applied as a final step to the abstract representations obtained from set attention |
104 | 104 | self.pooling = keras.layers.GlobalAvgPool1D() |
105 | 105 | self.output_projector = keras.layers.Dense(summary_dim) |
106 | | - self.summary_dim = summary_dim |
107 | 106 |
|
| 107 | + # store variables for serialization |
| 108 | + self.summary_dim = summary_dim |
| 109 | + self.embed_dims = embed_dims |
| 110 | + self.num_heads = num_heads |
| 111 | + self.mlp_depths = mlp_depths |
| 112 | + self.mlp_widths = mlp_widths |
| 113 | + self.dropout = dropout |
| 114 | + self.mlp_activation = mlp_activation |
| 115 | + self.kernel_initializer = kernel_initializer |
| 116 | + self.use_bias = use_bias |
| 117 | + self.layer_norm = layer_norm |
| 118 | + self._time_embedding_arg = time_embedding |
| 119 | + self.time_embed_dim = time_embed_dim |
108 | 120 | self.time_axis = time_axis |
| 121 | + self._kwargs = kwargs |
109 | 122 |
|
110 | 123 | def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor: |
111 | 124 | """Compresses the input sequence into a summary vector of size `summary_dim`. |
@@ -147,3 +160,30 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens |
147 | 160 | summary = self.pooling(inp) |
148 | 161 | summary = self.output_projector(summary) |
149 | 162 | return summary |
| 163 | + |
| 164 | + @classmethod |
| 165 | + def from_config(cls, config, custom_objects=None): |
| 166 | + return cls(**deserialize(config, custom_objects=custom_objects)) |
| 167 | + |
| 168 | + def get_config(self): |
| 169 | + base_config = super().get_config() |
| 170 | + base_config = model_kwargs(base_config) |
| 171 | + |
| 172 | + config = { |
| 173 | + "summary_dim": self.summary_dim, |
| 174 | + "embed_dims": self.embed_dims, |
| 175 | + "num_heads": self.num_heads, |
| 176 | + "mlp_depths": self.mlp_depths, |
| 177 | + "mlp_widths": self.mlp_widths, |
| 178 | + "dropout": self.dropout, |
| 179 | + "mlp_activation": self.mlp_activation, |
| 180 | + "kernel_initializer": self.kernel_initializer, |
| 181 | + "use_bias": self.use_bias, |
| 182 | + "layer_norm": self.layer_norm, |
| 183 | + "time_embedding": self._time_embedding_arg, |
| 184 | + "time_embed_dim": self.time_embed_dim, |
| 185 | + "time_axis": self.time_axis, |
| 186 | + **self._kwargs, |
| 187 | + } |
| 188 | + |
| 189 | + return base_config | serialize(config) |
0 commit comments