 
 from .base_modules import Activation, Norm
 
-__all__ = ["Mlp", "MlpBlock"]
+__all__ = ["Mlp", "ConvMlp", "MlpBlock"]
 
 
 class Mlp(nn.Module):
@@ -17,7 +17,8 @@ def __init__(
         dropout: float = 0.0,
         bias: bool = False,
         out_channels: int = None,
-        **act_kwargs
+        act_kwargs: Dict[str, Any] = None,
+        **kwargs,
     ) -> None:
         """MLP token mixer.
 
@@ -32,7 +33,7 @@ def __init__(
         in_channels : int
             Number of input features.
         mlp_ratio : int, default=2
-            Scaling factor to get the number hidden features from the `in_features`.
+            Scaling factor to get the number of hidden features from `in_channels`.
         activation : str, default="star_relu"
             The name of the activation function.
         dropout : float, default=0.0
@@ -41,10 +42,11 @@ def __init__(
             Flag whether to use bias terms in the nn.Linear modules.
         out_channels : int, optional
             Number of out channels. If None `out_channels = in_channels`
-        **act_kwargs:
+        act_kwargs : Dict[str, Any], optional
             Arbitrary key-word arguments for the activation function.
         """
         super().__init__()
+        act_kwargs = act_kwargs if act_kwargs is not None else {}
         self.out_channels = in_channels if out_channels is None else out_channels
         hidden_channels = int(mlp_ratio * in_channels)
 
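
Note the call-convention change above: activation kwargs now travel as one explicit `act_kwargs` dict instead of loose `**act_kwargs`. A minimal usage sketch (the token shape and the `inplace` kwarg are illustrative assumptions; the module is a token mixer built on nn.Linear, so leading dims should broadcast through):

    import torch

    # hypothetical sizes: batch 8, 196 tokens, 64 channels
    mlp = Mlp(in_channels=64, mlp_ratio=2, act_kwargs={"inplace": True})
    x = torch.randn(8, 196, 64)
    y = mlp(x)  # (8, 196, 64): out_channels defaults to in_channels
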
@@ -65,13 +67,73 @@ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         return x
 
 
+class ConvMlp(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        mlp_ratio: int = 2,
+        activation: str = "star_relu",
+        dropout: float = 0.0,
+        bias: bool = False,
+        out_channels: int = None,
+        act_kwargs: Dict[str, Any] = None,
+        **kwargs,
+    ) -> None:
82+ """Mlp layer implemented with dws convolution.
83+
84+ Input shape: (B, in_channels, H, W).
85+ Output shape: (B, out_channels, H, W).
86+
87+ Parameters
88+ ----------
89+ in_channels : int
90+ Number of input features.
91+ mlp_ratio : int, default=2
92+ Scaling factor to get the number hidden features from the `in_channels`.
93+ activation : str, default="star_relu"
94+ The name of the activation function.
95+ dropout : float, default=0.0
96+ Dropout ratio.
97+ bias : bool, default=False
98+ Flag whether to use bias terms in the nn.Linear modules.
99+ out_channels : int, optional
100+ Number of out channels. If None `out_channels = in_channels`
101+ act_kwargs : Dict[str, Any], optional
102+ Arbitrary key-word arguments for the activation function.
103+ """
+        super().__init__()
+        act_kwargs = act_kwargs if act_kwargs is not None else {}
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.hidden_channels = int(mlp_ratio * in_channels)
+        self.fc1 = nn.Conv2d(in_channels, self.hidden_channels, 1, bias=bias)
+        # the depthwise conv runs after `fc1`, i.e. on the expanded features,
+        # so it has to be built on `hidden_channels`, not `in_channels`
+        self.dwconv = nn.Conv2d(
+            self.hidden_channels,
+            self.hidden_channels,
+            3,
+            1,
+            1,
+            bias=bias,
+            groups=self.hidden_channels,
+        )
+        self.act = Activation(activation, **act_kwargs)
+        self.fc2 = nn.Conv2d(self.hidden_channels, self.out_channels, 1, bias=bias)
+        self.drop = nn.Dropout(dropout)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of conv-mlp."""
+        x = self.fc1(x)
+
+        x = self.dwconv(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+
+        return x
+
+
 class MlpBlock(nn.Module):
     def __init__(
         self,
         in_channels: int,
+        mlp_type: str = "linear",
         mlp_ratio: int = 2,
         activation: str = "star_relu",
-        activation_kwargs: Dict[str, Any] = None,
+        act_kwargs: Dict[str, Any] = None,
         dropout: float = 0.0,
         bias: bool = False,
         normalization: str = "ln",
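
A quick shape check of the new ConvMlp (a sketch with assumed example sizes): `fc1` expands the channels by `mlp_ratio`, the depthwise 3x3 mixes spatially at the expanded width, and `fc2` projects back, so the spatial dims pass through unchanged:

    import torch

    conv_mlp = ConvMlp(in_channels=64, mlp_ratio=2)  # hidden_channels = 128
    x = torch.randn(2, 64, 32, 32)  # (B, in_channels, H, W)
    y = conv_mlp(x)
    assert y.shape == (2, 64, 32, 32)  # (B, out_channels, H, W)
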
@@ -85,10 +147,15 @@ def __init__(
         ----------
         in_channels : int
             Number of input features.
+        mlp_type : str, default="linear"
+            Flag for either an nn.Linear or an nn.Conv2d mlp-layer.
+            One of: "linear", "conv".
         mlp_ratio : int, default=2
-            Scaling factor to get the number hidden features from the `in_features`.
+            Scaling factor to get the number of hidden features from `in_channels`.
         activation : str, default="star_relu"
             The name of the activation function.
+        act_kwargs : Dict[str, Any], optional
+            Key-word arguments for the activation module.
         dropout : float, default=0.0
             Dropout ratio.
         bias : bool, default=False
@@ -101,14 +168,24 @@ def __init__(
             is None.
         """
         super().__init__()
+        allowed = ("conv", "linear")
+        if mlp_type not in allowed:
+            raise ValueError(
+                f"Illegal `mlp_type` given. Got: {mlp_type}. Allowed: {allowed}."
+            )
+
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+        act_kwargs = act_kwargs if act_kwargs is not None else {}
         self.norm = Norm(normalization, **norm_kwargs)
-        self.mlp = Mlp(
+        MlpHead = Mlp if mlp_type == "linear" else ConvMlp
+
+        self.mlp = MlpHead(
             in_channels=in_channels,
             mlp_ratio=mlp_ratio,
             activation=activation,
             dropout=dropout,
             bias=bias,
-            **activation_kwargs
+            act_kwargs=act_kwargs,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
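
And the new `mlp_type` switch in use, a sketch assuming `Norm("ln")` builds without extra kwargs here:

    # "linear" -> Mlp, "conv" -> ConvMlp; anything else raises
    linear_block = MlpBlock(in_channels=64, mlp_type="linear")
    conv_block = MlpBlock(in_channels=64, mlp_type="conv")
    MlpBlock(in_channels=64, mlp_type="mlp")  # ValueError: Illegal `mlp_type` given.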