@@ -28,7 +28,7 @@ def __init__(
2828 output_pooling : str = "mean" ,
2929 mlp_widths_equivariant : Sequence [int ] = (64 , 64 ),
3030 mlp_widths_invariant_inner : Sequence [int ] = (64 , 64 ),
31- mlp_widths_invariant_outer : Sequence [int ] = (64 , 64 ),
31+ mlp_widths_invariant_outer : Sequence [int ] = (64 , 4 ),
3232 mlp_widths_invariant_last : Sequence [int ] = (64 , 64 ),
3333 activation : str = "silu" ,
3434 kernel_initializer : str = "he_normal" ,
@@ -68,7 +68,7 @@ def __init__(
6868 mlp_widths_invariant_inner : Sequence[int], optional
6969 Widths of the inner MLP layers within the invariant module. Default is (64, 64).
7070 mlp_widths_invariant_outer : Sequence[int], optional
71- Widths of the outer MLP layers within the invariant module. Default is (64, 64 ).
71+ Widths of the outer MLP layers within the invariant module. Default is (64, 4 ).
7272 mlp_widths_invariant_last : Sequence[int], optional
7373 Widths of the MLP layers in the final invariant transformation. Default is (64, 64).
7474 activation : str, optional
@@ -80,7 +80,7 @@ def __init__(
8080 spectral_normalization : bool, optional
8181 Whether to apply spectral normalization to stabilize training. Default is False.
8282 **kwargs
83- Additional keyword arguments passed to the equivariant and invariant modules .
83+ Additional keyword arguments passed to the base class .
8484 """
8585
8686 super ().__init__ (** kwargs )
0 commit comments