Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions bayesflow/networks/deep_set/equivariant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,13 @@ def __init__(
# Fully connected net + residual connection for an equivariant transform applied to each set member
self.input_projector = layers.Dense(mlp_widths_equivariant[-1])
self.equivariant_fc = MLP(
mlp_widths_equivariant,
mlp_widths_equivariant[:-1],
dropout=dropout,
activation=activation,
kernel_initializer=kernel_initializer,
spectral_normalization=spectral_normalization,
)
self.out_fc_projector = keras.layers.Dense(mlp_widths_equivariant[-1], kernel_initializer=kernel_initializer)

self.layer_norm = layers.LayerNormalization() if layer_norm else None

Expand Down Expand Up @@ -137,7 +138,10 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
output_set = ops.concatenate([input_set, invariant_summary], axis=-1)

# Pass through final equivariant transform + residual
output_set = input_set + self.equivariant_fc(output_set, training=training)
out_fc = self.equivariant_fc(output_set, training=training)
out_projected = self.out_fc_projector(out_fc)
output_set = input_set + out_projected

if self.layer_norm is not None:
output_set = self.layer_norm(output_set, training=training)

Expand Down
8 changes: 6 additions & 2 deletions bayesflow/networks/deep_set/invariant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,22 @@ def __init__(

# Inner fully connected net for sum decomposition: inner( pooling( inner(set) ) )
self.inner_fc = MLP(
mlp_widths_inner,
mlp_widths_inner[:-1],
dropout=dropout,
activation=activation,
kernel_initializer=kernel_initializer,
spectral_normalization=spectral_normalization,
)
self.inner_projector = keras.layers.Dense(mlp_widths_inner[-1], kernel_initializer=kernel_initializer)

self.outer_fc = MLP(
mlp_widths_outer,
mlp_widths_outer[:-1],
dropout=dropout,
activation=activation,
kernel_initializer=kernel_initializer,
spectral_normalization=spectral_normalization,
)
self.outer_projector = keras.layers.Dense(mlp_widths_outer[-1], kernel_initializer=kernel_initializer)

# Pooling function as keras layer for sum decomposition: inner( pooling( inner(set) ) )
if pooling_kwargs is None:
Expand All @@ -106,8 +108,10 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
"""

set_summary = self.inner_fc(input_set, training=training)
set_summary = self.inner_projector(set_summary, training=training)
set_summary = self.pooling_layer(set_summary, training=training)
set_summary = self.outer_fc(set_summary, training=training)
set_summary = self.outer_projector(set_summary, training=training)
return set_summary

@sanitize_input_shape
Expand Down