
Commit 09798de

Initialize projectors for invariant and equivariant DeepSet layers
1 parent: 7c9efd4


2 files changed: 7 additions, 4 deletions


bayesflow/networks/deep_set/equivariant_layer.py

Lines changed: 3 additions & 2 deletions
@@ -88,12 +88,13 @@ def __init__(
         # Fully connected net + residual connection for an equivariant transform applied to each set member
         self.input_projector = layers.Dense(mlp_widths_equivariant[-1])
         self.equivariant_fc = MLP(
-            mlp_widths_equivariant,
+            mlp_widths_equivariant[:-1],
             dropout=dropout,
             activation=activation,
             kernel_initializer=kernel_initializer,
             spectral_normalization=spectral_normalization,
         )
+        self.out_fc_projector = keras.layers.Dense(mlp_widths_equivariant[-1], kernel_initializer=kernel_initializer)

         self.layer_norm = layers.LayerNormalization() if layer_norm else None


@@ -140,7 +141,7 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         out_fc = self.equivariant_fc(output_set, training=training)
         out_projected = self.out_fc_projector(out_fc)
         output_set = input_set + out_projected
-        # output_set = input_set + self.equivariant_fc(output_set, training=training)
+
         if self.layer_norm is not None:
             output_set = self.layer_norm(output_set, training=training)
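For orientation, the block below is a minimal, self-contained sketch of the residual pattern this diff converges on: an MLP trunk over mlp_widths_equivariant[:-1] followed by a Dense projector with the final width, added back to a projected copy of the input set. keras.Sequential of Dense layers stands in for bayesflow's internal MLP, the widths and shapes are invented for illustration, and the exact placement of input_projector relative to the residual is an assumption, since the full call method is not shown in this hunk.

# Sketch only: keras.Sequential of Dense layers stands in for bayesflow's MLP;
# widths, shapes, and the residual placement below are illustrative assumptions.
import keras
from keras import layers

mlp_widths_equivariant = (64, 64, 32)  # hypothetical widths; the last entry sets the residual width

input_projector = layers.Dense(mlp_widths_equivariant[-1])
equivariant_fc = keras.Sequential(
    [layers.Dense(w, activation="gelu") for w in mlp_widths_equivariant[:-1]]
)
out_fc_projector = layers.Dense(mlp_widths_equivariant[-1])
layer_norm = layers.LayerNormalization()


def equivariant_block(input_set):
    # input_set: (batch, set_size, num_features); every layer acts on the last
    # axis only, so the same transform is applied to each set member.
    projected_input = input_projector(input_set)   # bring the set to the residual width
    out_fc = equivariant_fc(projected_input)       # per-member MLP trunk over widths[:-1]
    out_projected = out_fc_projector(out_fc)       # Dense projector with the final width
    output_set = projected_input + out_projected   # residual connection (placement assumed)
    return layer_norm(output_set)


x = keras.random.normal((8, 5, 10))   # (batch=8, set_size=5, num_features=10)
print(equivariant_block(x).shape)     # (8, 5, 32)

Splitting the last width out into a plain Dense projector keeps the residual width fixed at mlp_widths_equivariant[-1] no matter how the MLP trunk is configured.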

bayesflow/networks/deep_set/invariant_layer.py

Lines changed: 4 additions & 2 deletions
@@ -68,20 +68,22 @@ def __init__(

         # Inner fully connected net for sum decomposition: inner( pooling( inner(set) ) )
         self.inner_fc = MLP(
-            mlp_widths_inner,
+            mlp_widths_inner[:-1],
             dropout=dropout,
             activation=activation,
             kernel_initializer=kernel_initializer,
             spectral_normalization=spectral_normalization,
         )
+        self.inner_projector = keras.layers.Dense(mlp_widths_inner[-1], kernel_initializer=kernel_initializer)

         self.outer_fc = MLP(
-            mlp_widths_outer,
+            mlp_widths_outer[:-1],
             dropout=dropout,
             activation=activation,
             kernel_initializer=kernel_initializer,
             spectral_normalization=spectral_normalization,
         )
+        self.outer_projector = keras.layers.Dense(mlp_widths_outer[-1], kernel_initializer=kernel_initializer)

         # Pooling function as keras layer for sum decomposition: inner( pooling( inner(set) ) )
         if pooling_kwargs is None:
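As a companion sketch, here is the sum decomposition inner( pooling( inner(set) ) ) that the invariant layer's comments describe, with the new per-stage projectors. Again, keras.Sequential stands in for bayesflow's MLP, mean pooling stands in for the configurable pooling layer built from pooling_kwargs, and the widths and shapes are invented for illustration.

# Sketch only: stand-ins for bayesflow's MLP and configurable pooling; widths/shapes invented.
import keras
from keras import layers, ops

mlp_widths_inner = (64, 64, 32)   # hypothetical; last entry is the inner projector width
mlp_widths_outer = (64, 64, 16)   # hypothetical; last entry is the outer projector width

inner_fc = keras.Sequential([layers.Dense(w, activation="gelu") for w in mlp_widths_inner[:-1]])
inner_projector = layers.Dense(mlp_widths_inner[-1])
outer_fc = keras.Sequential([layers.Dense(w, activation="gelu") for w in mlp_widths_outer[:-1]])
outer_projector = layers.Dense(mlp_widths_outer[-1])


def invariant_block(input_set):
    # input_set: (batch, set_size, num_features)
    inner = inner_projector(inner_fc(input_set))   # per-member embedding: MLP trunk + Dense projector
    pooled = ops.mean(inner, axis=1)               # permutation-invariant pooling over the set axis
    return outer_projector(outer_fc(pooled))       # summary of shape (batch, mlp_widths_outer[-1])


x = keras.random.normal((8, 5, 10))  # (batch=8, set_size=5, num_features=10)
print(invariant_block(x).shape)      # (8, 16)

Because an element-wise transform followed by a symmetric pooling is permutation invariant, swapping the assumed mean pooling for whatever pooling layer pooling_kwargs configures preserves the invariance.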

0 commit comments
