Skip to content

Commit 17529e7

Browse files
authored
extending the mlp module (#4089)
* extend mlp
* 0 mlp_dim (allow `mlp_dim=0` to default to `hidden_size`)
* update based on comments

Signed-off-by: Wenqi Li <[email protected]>
1 parent d68554e commit 17529e7

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

monai/networks/blocks/mlp.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,61 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from typing import Tuple, Union
13+
1214
import torch.nn as nn
1315

16+
from monai.networks.layers import get_act_layer
17+
from monai.utils import look_up_option
18+
19+
# Valid values for the ``dropout_mode`` argument of ``MLPBlock``:
# "vit" uses two independent Dropout instances, "swin" shares a single one.
SUPPORTED_DROPOUT_MODE = {"vit", "swin"}
20+
1421

1522
class MLPBlock(nn.Module):
1623
"""
1724
A multi-layer perceptron block, based on: "Dosovitskiy et al.,
1825
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
1926
"""
2027

21-
def __init__(self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0) -> None:
28+
def __init__(
    self,
    hidden_size: int,
    mlp_dim: int,
    dropout_rate: float = 0.0,
    act: Union[Tuple, str] = "GELU",
    dropout_mode: str = "vit",
) -> None:
    """
    Build the two-layer MLP block used by transformer architectures.

    Args:
        hidden_size: dimension of hidden layer.
        mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
        dropout_rate: fraction of the input units to drop.
        act: activation type and arguments. Defaults to GELU.
        dropout_mode: dropout mode, can be "vit" or "swin".
            "vit" mode uses two dropout instances as implemented in
            https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
            "swin" corresponds to one instance as implemented in
            https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23

    Raises:
        ValueError: when ``dropout_rate`` is outside [0, 1] or ``dropout_mode``
            is not one of ``SUPPORTED_DROPOUT_MODE``.
    """
    super().__init__()

    if not (0 <= dropout_rate <= 1):
        raise ValueError("dropout_rate should be between 0 and 1.")
    # mlp_dim == 0 is a documented shorthand for "same width as hidden_size".
    mlp_dim = mlp_dim or hidden_size
    self.linear1 = nn.Linear(hidden_size, mlp_dim)
    self.linear2 = nn.Linear(mlp_dim, hidden_size)
    self.fn = get_act_layer(act)
    self.drop1 = nn.Dropout(dropout_rate)
    dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
    if dropout_opt == "vit":
        # ViT reference: an independent Dropout after each linear layer.
        self.drop2 = nn.Dropout(dropout_rate)
    elif dropout_opt == "swin":
        # Swin reference: one shared Dropout instance reused in both positions.
        self.drop2 = self.drop1
    else:
        # NOTE(review): presumably unreachable if look_up_option already raises
        # on unsupported values — kept as a defensive guard; confirm.
        raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")
4067

4168
def forward(self, x):
4269
x = self.fn(self.linear1(x))

tests/test_mlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
TEST_CASE_MLP = []
2222
for dropout_rate in np.linspace(0, 1, 4):
2323
for hidden_size in [128, 256, 512, 768]:
24-
for mlp_dim in [512, 1028, 2048, 3072]:
24+
for mlp_dim in [0, 1028, 2048, 3072]:
2525

2626
test_case = [
2727
{"hidden_size": hidden_size, "mlp_dim": mlp_dim, "dropout_rate": dropout_rate},

0 commit comments

Comments
 (0)