Skip to content

Commit b8d870f

Browse files
committed
store summary_dim in summary networks init consistently after output_projector
1 parent 3d53e39 commit b8d870f

File tree

4 files changed

+3
-2
lines changed

4 files changed

+3
-2
lines changed

bayesflow/networks/lstnet/lstnet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def __init__(
3737
**kwargs,
3838
):
3939
super().__init__(**kwargs)
40-
self.summary_dim = summary_dim
4140

4241
# Convolutional backbone -> can be extended with inception-like structure
4342
if not isinstance(filters, (list, tuple)):

bayesflow/networks/transformers/fusion_transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def __init__(
120120
raise ValueError("Argument `template_dim` should be in ['lstm', 'gru']")
121121

122122
self.output_projector = keras.layers.Dense(summary_dim)
123+
self.summary_dim = summary_dim
123124

124125
def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
125126
"""Compresses the input sequence into a summary vector of size `summary_dim`.

bayesflow/networks/transformers/set_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def __init__(
8080
"""
8181

8282
super().__init__(**kwargs)
83-
self.summary_dim = summary_dim
8483

8584
check_lengths_same(embed_dims, num_heads, mlp_depths, mlp_widths)
8685

@@ -126,6 +125,7 @@ def __init__(
126125
)
127126
self.pooling_by_attention = PoolingByMultiHeadAttention(**(global_attention_settings | pooling_settings))
128127
self.output_projector = keras.layers.Dense(summary_dim)
128+
self.summary_dim = summary_dim
129129

130130
def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
131131
"""Compresses the input sequence into a summary vector of size `summary_dim`.

bayesflow/networks/transformers/time_series_transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
# Pooling will be applied as a final step to the abstract representations obtained from set attention
104104
self.pooling = keras.layers.GlobalAvgPool1D()
105105
self.output_projector = keras.layers.Dense(summary_dim)
106+
self.summary_dim = summary_dim
106107

107108
self.time_axis = time_axis
108109

0 commit comments

Comments
 (0)