Skip to content

Commit 2917f68

Browse files
committed
Fix loading serialized transformers
1 parent e69fdd7 commit 2917f68

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

bayesflow/networks/transformers/fusion_transformer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from bayesflow.types import Tensor
66
from bayesflow.utils import check_lengths_same
7+
from bayesflow.utils.decorators import sanitize_input_shape
78

89
from ..summary_network import SummaryNetwork
910

@@ -151,3 +152,8 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens
151152
summary = self.attention_blocks[-1](keras.ops.expand_dims(template, axis=1), rep, training=training, **kwargs)
152153
summary = self.output_projector(keras.ops.squeeze(summary, axis=1))
153154
return summary
155+
156+
@sanitize_input_shape
157+
def build(self, input_shape):
158+
super().build(input_shape)
159+
self.call(keras.ops.zeros(input_shape))

bayesflow/networks/transformers/set_transformer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from bayesflow.types import Tensor
55
from bayesflow.utils import check_lengths_same
6+
from bayesflow.utils.decorators import sanitize_input_shape
67

78
from ..summary_network import SummaryNetwork
89

@@ -150,3 +151,8 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
150151
summary = self.pooling_by_attention(summary, training=training, **kwargs)
151152
summary = self.output_projector(summary)
152153
return summary
154+
155+
@sanitize_input_shape
156+
def build(self, input_shape):
157+
super().build(input_shape)
158+
self.call(keras.ops.zeros(input_shape))

bayesflow/networks/transformers/time_series_transformer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from bayesflow.types import Tensor
55
from bayesflow.utils import check_lengths_same
6+
from bayesflow.utils.decorators import sanitize_input_shape
67

78
from ..embeddings import Time2Vec, RecurrentEmbedding
89
from ..summary_network import SummaryNetwork
@@ -147,3 +148,8 @@ def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tens
147148
summary = self.pooling(inp)
148149
summary = self.output_projector(summary)
149150
return summary
151+
152+
@sanitize_input_shape
153+
def build(self, input_shape):
154+
super().build(input_shape)
155+
self.call(keras.ops.zeros(input_shape))

0 commit comments

Comments
 (0)