diff --git a/bayesflow/networks/deep_set/deep_set.py b/bayesflow/networks/deep_set/deep_set.py index 4daf406b5..7d0427b77 100644 --- a/bayesflow/networks/deep_set/deep_set.py +++ b/bayesflow/networks/deep_set/deep_set.py @@ -28,7 +28,7 @@ def __init__( output_pooling: str = "mean", mlp_widths_equivariant: Sequence[int] = (64, 64), mlp_widths_invariant_inner: Sequence[int] = (64, 64), - mlp_widths_invariant_outer: Sequence[int] = (64, 64), + mlp_widths_invariant_outer: Sequence[int] = (64, 4), mlp_widths_invariant_last: Sequence[int] = (64, 64), activation: str = "silu", kernel_initializer: str = "he_normal", @@ -68,7 +68,7 @@ def __init__( mlp_widths_invariant_inner : Sequence[int], optional Widths of the inner MLP layers within the invariant module. Default is (64, 64). mlp_widths_invariant_outer : Sequence[int], optional - Widths of the outer MLP layers within the invariant module. Default is (64, 64). + Widths of the outer MLP layers within the invariant module. Default is (64, 4). mlp_widths_invariant_last : Sequence[int], optional Widths of the MLP layers in the final invariant transformation. Default is (64, 64). activation : str, optional @@ -80,7 +80,7 @@ def __init__( spectral_normalization : bool, optional Whether to apply spectral normalization to stabilize training. Default is False. **kwargs - Additional keyword arguments passed to the equivariant and invariant modules. + Additional keyword arguments passed to the base class. """ super().__init__(**kwargs)