
Commit 7e0363e

implement requested changes and improve activation

1 parent 82d5bc1

3 files changed: +7 -7 lines changed

3 files changed

+7
-7
lines changed

bayesflow/networks/deep_set/deep_set.py

Lines changed: 2 additions & 2 deletions
@@ -30,7 +30,7 @@ def __init__(
         mlp_widths_invariant_inner: Sequence[int] = (64, 64),
         mlp_widths_invariant_outer: Sequence[int] = (64, 64),
         mlp_widths_invariant_last: Sequence[int] = (64, 64),
-        activation: str = "gelu",
+        activation: str = "silu",
         kernel_initializer: str = "he_normal",
         dropout: int | float | None = 0.05,
         spectral_normalization: bool = False,
@@ -72,7 +72,7 @@ def __init__(
        mlp_widths_invariant_last : Sequence[int], optional
            Widths of the MLP layers in the final invariant transformation. Default is (64, 64).
        activation : str, optional
-           Activation function used throughout the network, such as "gelu". Default is "gelu".
+           Activation function used throughout the network, such as "gelu". Default is "silu".
        kernel_initializer : str, optional
            Initialization strategy for kernel weights, such as "he_normal". Default is "he_normal".
        dropout : int, float, or None, optional
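For context, SiLU (silu(x) = x * sigmoid(x), also known as swish) is a smooth activation similar to GELU. A minimal usage sketch of the new default, assuming DeepSet is importable from bayesflow.networks and the remaining constructor defaults match the signature in the hunk above:

    from bayesflow.networks import DeepSet

    # After this commit, omitting `activation` is equivalent to passing
    # "silu"; it is spelled out here only to make the new default visible.
    summary_net = DeepSet(
        mlp_widths_invariant_inner=(64, 64),
        mlp_widths_invariant_outer=(64, 64),
        mlp_widths_invariant_last=(64, 64),
        activation="silu",  # was "gelu" before this commit
        kernel_initializer="he_normal",
        dropout=0.05,
    )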

bayesflow/networks/deep_set/equivariant_layer.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def __init__(
         # Fully connected net + residual connection for an equivariant transform applied to each set member
         self.input_projector = layers.Dense(mlp_widths_equivariant[-1])
         self.equivariant_fc = MLP(
-            mlp_widths_equivariant[:-1],
+            mlp_widths_equivariant,
             dropout=dropout,
             activation=activation,
             kernel_initializer=kernel_initializer,
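The change feeds the full mlp_widths_equivariant list to the MLP instead of dropping its last entry, so the hidden stack itself now ends at the residual width. A minimal sketch of the residual pattern, assuming MLP behaves like a plain stack of Dense layers (the standalone names here are illustrative, not the library's API):

    import keras
    from keras import layers

    mlp_widths_equivariant = (64, 64)

    # Project each set member to the MLP's output width so the residual
    # addition below is shape-compatible.
    input_projector = layers.Dense(mlp_widths_equivariant[-1])
    equivariant_fc = keras.Sequential(
        [layers.Dense(width, activation="silu") for width in mlp_widths_equivariant]
    )

    def equivariant_transform(set_members):
        # Applied independently to each set member, hence permutation-equivariant.
        return input_projector(set_members) + equivariant_fc(set_members)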

bayesflow/networks/deep_set/invariant_layer.py

Lines changed: 4 additions & 4 deletions
@@ -68,7 +68,7 @@ def __init__(

         # Inner fully connected net for sum decomposition: outer( pooling( inner(set) ) )
         self.inner_fc = MLP(
-            mlp_widths_inner[:-1],
+            mlp_widths_inner,
             dropout=dropout,
             activation=activation,
             kernel_initializer=kernel_initializer,
@@ -77,7 +77,7 @@ def __init__(
         self.inner_projector = keras.layers.Dense(mlp_widths_inner[-1], kernel_initializer=kernel_initializer)

         self.outer_fc = MLP(
-            mlp_widths_outer[:-1],
+            mlp_widths_outer,
             dropout=dropout,
             activation=activation,
             kernel_initializer=kernel_initializer,
@@ -108,10 +108,10 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         """

         set_summary = self.inner_fc(input_set, training=training)
-        set_summary = self.inner_projector(set_summary, training=training)
+        set_summary = self.inner_projector(set_summary)
         set_summary = self.pooling_layer(set_summary, training=training)
         set_summary = self.outer_fc(set_summary, training=training)
-        set_summary = self.outer_projector(set_summary, training=training)
+        set_summary = self.outer_projector(set_summary)
         return set_summary

     @sanitize_input_shape
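This layer follows the Deep Sets sum-decomposition f(X) = outer(pool_i(inner(x_i))): an inner net applied per member, a permutation-invariant pool, then an outer net. The Dense projectors have no training-dependent behavior, which is why the training= argument is dropped from their calls above. A self-contained sketch of the forward pass under these assumptions (mean pooling; layer names mirror the diff; the real MLP class additionally carries dropout and the configured activation):

    import keras
    from keras import layers, ops

    # Illustrative stand-ins for the layers built in __init__ above.
    inner_fc = keras.Sequential([layers.Dense(64, activation="silu") for _ in range(2)])
    inner_projector = layers.Dense(64)
    outer_fc = keras.Sequential([layers.Dense(64, activation="silu") for _ in range(2)])
    outer_projector = layers.Dense(64)

    def invariant_summary(input_set, training=False):
        # input_set: (batch, set_size, features)
        s = inner_fc(input_set, training=training)   # per-member embedding
        s = inner_projector(s)                       # Dense: no training behavior
        s = ops.mean(s, axis=1)                      # permutation-invariant pooling
        s = outer_fc(s, training=training)           # post-pooling transform
        return outer_projector(s)                    # Dense: no training behavior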
