|
2 | 2 | from keras import layers |
3 | 3 |
|
4 | 4 | from bayesflow.types import Tensor |
5 | | -from bayesflow.utils import check_lengths_same |
6 | | -from bayesflow.utils.serialization import serializable |
| 5 | +from bayesflow.utils import check_lengths_same, model_kwargs |
| 6 | +from bayesflow.utils.serialization import deserialize, serializable, serialize |
7 | 7 |
|
8 | 8 | from ..summary_network import SummaryNetwork |
9 | 9 |
|
@@ -121,6 +121,19 @@ def __init__( |
121 | 121 |
|
122 | 122 | self.output_projector = keras.layers.Dense(summary_dim) |
123 | 123 | self.summary_dim = summary_dim |
| 124 | + self.embed_dims = embed_dims |
| 125 | + self.num_heads = num_heads |
| 126 | + self.mlp_depths = mlp_depths |
| 127 | + self.mlp_widths = mlp_widths |
| 128 | + self.dropout = dropout |
| 129 | + self.mlp_activation = mlp_activation |
| 130 | + self.kernel_initializer = kernel_initializer |
| 131 | + self.use_bias = use_bias |
| 132 | + self.layer_norm = layer_norm |
| 133 | + self.template_type = template_type |
| 134 | + self.bidirectional = bidirectional |
| 135 | + self.template_dim = template_dim |
| 136 | + self._kwargs = kwargs |
124 | 137 |
|
125 | 138 | def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor: |
126 | 139 | """Compresses the input sequence into a summary vector of size `summary_dim`. |
@@ -151,3 +164,30 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens |
151 | 164 | summary = self.attention_blocks[-1](keras.ops.expand_dims(template, axis=1), rep, training=training, **kwargs) |
152 | 165 | summary = self.output_projector(keras.ops.squeeze(summary, axis=1)) |
153 | 166 | return summary |
| 167 | + |
| 168 | + @classmethod |
| 169 | + def from_config(cls, config, custom_objects=None): |
| 170 | + return cls(**deserialize(config, custom_objects=custom_objects)) |
| 171 | + |
| 172 | + def get_config(self): |
| 173 | + base_config = super().get_config() |
| 174 | + base_config = model_kwargs(base_config) |
| 175 | + |
| 176 | + config = { |
| 177 | + "summary_dim": self.summary_dim, |
| 178 | + "embed_dims": self.embed_dims, |
| 179 | + "num_heads": self.num_heads, |
| 180 | + "mlp_depths": self.mlp_depths, |
| 181 | + "mlp_widths": self.mlp_widths, |
| 182 | + "dropout": self.dropout, |
| 183 | + "mlp_activation": self.mlp_activation, |
| 184 | + "kernel_initializer": self.kernel_initializer, |
| 185 | + "use_bias": self.use_bias, |
| 186 | + "layer_norm": self.layer_norm, |
| 187 | + "template_type": self.template_type, |
| 188 | + "bidirectional": self.bidirectional, |
| 189 | + "template_dim": self.template_dim, |
| 190 | + **self._kwargs, |
| 191 | + } |
| 192 | + |
| 193 | + return base_config | serialize(config) |
0 commit comments