Skip to content

Commit 2002aad

Browse files
committed
Revert changes in time series / fusion transformers
1 parent 79cf26f commit 2002aad

File tree

4 files changed

+11
-127
lines changed

4 files changed

+11
-127
lines changed

bayesflow/networks/coupling_flow/permutations/random.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@ def build(self, xz_shape: Shape, **kwargs) -> None:
1515

1616
self.forward_indices = self.add_weight(
1717
shape=(xz_shape[-1],),
18-
initializer=keras.initializers.Constant(forward_indices),
18+
# Best practice: https://github.com/keras-team/keras/pull/20457#discussion_r1832081248
19+
initializer=keras.initializers.get(forward_indices),
1920
trainable=False,
2021
dtype="int",
2122
)
2223

2324
self.inverse_indices = self.add_weight(
2425
shape=(xz_shape[-1],),
25-
initializer=keras.initializers.Constant(inverse_indices),
26+
# Best practice: https://github.com/keras-team/keras/pull/20457#discussion_r1832081248
27+
initializer=keras.initializers.get(inverse_indices),
2628
trainable=False,
2729
dtype="int",
2830
)

bayesflow/networks/transformers/fusion_transformer.py

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from keras import layers
33

44
from bayesflow.types import Tensor
5-
from bayesflow.utils import check_lengths_same, model_kwargs
6-
from bayesflow.utils.serialization import deserialize, serializable, serialize
5+
from bayesflow.utils import check_lengths_same
6+
from bayesflow.utils.serialization import serializable
77

88
from ..summary_network import SummaryNetwork
99

@@ -121,19 +121,6 @@ def __init__(
121121

122122
self.output_projector = keras.layers.Dense(summary_dim)
123123
self.summary_dim = summary_dim
124-
self.embed_dims = embed_dims
125-
self.num_heads = num_heads
126-
self.mlp_depths = mlp_depths
127-
self.mlp_widths = mlp_widths
128-
self.dropout = dropout
129-
self.mlp_activation = mlp_activation
130-
self.kernel_initializer = kernel_initializer
131-
self.use_bias = use_bias
132-
self.layer_norm = layer_norm
133-
self.template_type = template_type
134-
self.bidirectional = bidirectional
135-
self.template_dim = template_dim
136-
self._kwargs = kwargs
137124

138125
def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
139126
"""Compresses the input sequence into a summary vector of size `summary_dim`.
@@ -164,30 +151,3 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens
164151
summary = self.attention_blocks[-1](keras.ops.expand_dims(template, axis=1), rep, training=training, **kwargs)
165152
summary = self.output_projector(keras.ops.squeeze(summary, axis=1))
166153
return summary
167-
168-
@classmethod
169-
def from_config(cls, config, custom_objects=None):
170-
return cls(**deserialize(config, custom_objects=custom_objects))
171-
172-
def get_config(self):
173-
base_config = super().get_config()
174-
base_config = model_kwargs(base_config)
175-
176-
config = {
177-
"summary_dim": self.summary_dim,
178-
"embed_dims": self.embed_dims,
179-
"num_heads": self.num_heads,
180-
"mlp_depths": self.mlp_depths,
181-
"mlp_widths": self.mlp_widths,
182-
"dropout": self.dropout,
183-
"mlp_activation": self.mlp_activation,
184-
"kernel_initializer": self.kernel_initializer,
185-
"use_bias": self.use_bias,
186-
"layer_norm": self.layer_norm,
187-
"template_type": self.template_type,
188-
"bidirectional": self.bidirectional,
189-
"template_dim": self.template_dim,
190-
**self._kwargs,
191-
}
192-
193-
return base_config | serialize(config)

bayesflow/networks/transformers/set_transformer.py

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import keras
22

33
from bayesflow.types import Tensor
4-
from bayesflow.utils import check_lengths_same, model_kwargs
5-
from bayesflow.utils.serialization import deserialize, serializable, serialize
4+
from bayesflow.utils import check_lengths_same
5+
from bayesflow.utils.serialization import serializable
66

77
from ..summary_network import SummaryNetwork
88

@@ -129,18 +129,6 @@ def __init__(
129129
self.output_projector = keras.layers.Dense(summary_dim, name="output_projector")
130130

131131
self.summary_dim = summary_dim
132-
self.embed_dims = embed_dims
133-
self.num_heads = num_heads
134-
self.mlp_depths = mlp_depths
135-
self.mlp_widths = mlp_widths
136-
self.num_seeds = num_seeds
137-
self.dropout = dropout
138-
self.mlp_activation = mlp_activation
139-
self.kernel_initializer = kernel_initializer
140-
self.use_bias = use_bias
141-
self.layer_norm = layer_norm
142-
self.num_inducing_points = num_inducing_points
143-
self.seed_dim = seed_dim
144132

145133
def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
146134
"""Compresses the input sequence into a summary vector of size `summary_dim`.
@@ -165,29 +153,3 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
165153
summary = self.pooling_by_attention(summary, training=training, **kwargs)
166154
summary = self.output_projector(summary)
167155
return summary
168-
169-
@classmethod
170-
def from_config(cls, config, custom_objects=None):
171-
return cls(**deserialize(config, custom_objects=custom_objects))
172-
173-
def get_config(self):
174-
base_config = super().get_config()
175-
base_config = model_kwargs(base_config)
176-
177-
config = {
178-
"summary_dim": self.summary_dim,
179-
"embed_dims": self.embed_dims,
180-
"num_heads": self.num_heads,
181-
"mlp_depths": self.mlp_depths,
182-
"mlp_widths": self.mlp_widths,
183-
"num_seeds": self.num_seeds,
184-
"dropout": self.dropout,
185-
"mlp_activation": self.mlp_activation,
186-
"kernel_initializer": self.kernel_initializer,
187-
"use_bias": self.use_bias,
188-
"layer_norm": self.layer_norm,
189-
"num_inducing_points": self.num_inducing_points,
190-
"seed_dim": self.seed_dim,
191-
}
192-
193-
return base_config | serialize(config)

bayesflow/networks/transformers/time_series_transformer.py

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import keras
22

33
from bayesflow.types import Tensor
4-
from bayesflow.utils import check_lengths_same, model_kwargs
5-
from bayesflow.utils.serialization import deserialize, serializable, serialize
4+
from bayesflow.utils import check_lengths_same
5+
from bayesflow.utils.serialization import serializable
66

77
from ..embeddings import Time2Vec, RecurrentEmbedding
88
from ..summary_network import SummaryNetwork
@@ -103,22 +103,9 @@ def __init__(
103103
# Pooling will be applied as a final step to the abstract representations obtained from set attention
104104
self.pooling = keras.layers.GlobalAvgPool1D()
105105
self.output_projector = keras.layers.Dense(summary_dim)
106-
107-
# store variables for serialization
108106
self.summary_dim = summary_dim
109-
self.embed_dims = embed_dims
110-
self.num_heads = num_heads
111-
self.mlp_depths = mlp_depths
112-
self.mlp_widths = mlp_widths
113-
self.dropout = dropout
114-
self.mlp_activation = mlp_activation
115-
self.kernel_initializer = kernel_initializer
116-
self.use_bias = use_bias
117-
self.layer_norm = layer_norm
118-
self._time_embedding_arg = time_embedding
119-
self.time_embed_dim = time_embed_dim
107+
120108
self.time_axis = time_axis
121-
self._kwargs = kwargs
122109

123110
def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
124111
"""Compresses the input sequence into a summary vector of size `summary_dim`.
@@ -160,30 +147,3 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens
160147
summary = self.pooling(inp)
161148
summary = self.output_projector(summary)
162149
return summary
163-
164-
@classmethod
165-
def from_config(cls, config, custom_objects=None):
166-
return cls(**deserialize(config, custom_objects=custom_objects))
167-
168-
def get_config(self):
169-
base_config = super().get_config()
170-
base_config = model_kwargs(base_config)
171-
172-
config = {
173-
"summary_dim": self.summary_dim,
174-
"embed_dims": self.embed_dims,
175-
"num_heads": self.num_heads,
176-
"mlp_depths": self.mlp_depths,
177-
"mlp_widths": self.mlp_widths,
178-
"dropout": self.dropout,
179-
"mlp_activation": self.mlp_activation,
180-
"kernel_initializer": self.kernel_initializer,
181-
"use_bias": self.use_bias,
182-
"layer_norm": self.layer_norm,
183-
"time_embedding": self._time_embedding_arg,
184-
"time_embed_dim": self.time_embed_dim,
185-
"time_axis": self.time_axis,
186-
**self._kwargs,
187-
}
188-
189-
return base_config | serialize(config)

0 commit comments

Comments
 (0)