|
9 | 9 | # See the License for the specific language governing permissions and |
10 | 10 | # limitations under the License. |
11 | 11 |
|
| 12 | +from typing import Tuple, Union |
| 13 | + |
12 | 14 | import torch.nn as nn |
13 | 15 |
|
| 16 | +from monai.networks.layers import get_act_layer |
| 17 | +from monai.utils import look_up_option |
| 18 | + |
| 19 | +SUPPORTED_DROPOUT_MODE = {"vit", "swin"} |
| 20 | + |
14 | 21 |
|
15 | 22 | class MLPBlock(nn.Module): |
16 | 23 | """ |
17 | 24 | A multi-layer perceptron block, based on: "Dosovitskiy et al., |
18 | 25 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>" |
19 | 26 | """ |
20 | 27 |
|
21 | | - def __init__(self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0) -> None: |
| 28 | + def __init__( |
| 29 | + self, |
| 30 | + hidden_size: int, |
| 31 | + mlp_dim: int, |
| 32 | + dropout_rate: float = 0.0, |
| 33 | + act: Union[Tuple, str] = "GELU", |
| 34 | + dropout_mode="vit", |
| 35 | + ) -> None: |
22 | 36 | """ |
23 | 37 | Args: |
24 | 38 | hidden_size: dimension of hidden layer. |
25 | | - mlp_dim: dimension of feedforward layer. |
| 39 | + mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used. |
26 | 40 | dropout_rate: fraction of the input units to drop. |
| 41 | + act: activation type and arguments. Defaults to GELU. |
| 42 | + dropout_mode: dropout mode, can be "vit" or "swin". |
| 43 | + "vit" mode uses two dropout instances as implemented in |
| 44 | + https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87 |
| 45 | + "swin" corresponds to one instance as implemented in |
| 46 | + https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23 |
| 47 | +
|
27 | 48 |
|
28 | 49 | """ |
29 | 50 |
|
30 | 51 | super().__init__() |
31 | 52 |
|
32 | 53 | if not (0 <= dropout_rate <= 1): |
33 | 54 | raise ValueError("dropout_rate should be between 0 and 1.") |
34 | | - |
| 55 | + mlp_dim = mlp_dim or hidden_size |
35 | 56 | self.linear1 = nn.Linear(hidden_size, mlp_dim) |
36 | 57 | self.linear2 = nn.Linear(mlp_dim, hidden_size) |
37 | | - self.fn = nn.GELU() |
| 58 | + self.fn = get_act_layer(act) |
38 | 59 | self.drop1 = nn.Dropout(dropout_rate) |
39 | | - self.drop2 = nn.Dropout(dropout_rate) |
| 60 | + dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE) |
| 61 | + if dropout_opt == "vit": |
| 62 | + self.drop2 = nn.Dropout(dropout_rate) |
| 63 | + elif dropout_opt == "swin": |
| 64 | + self.drop2 = self.drop1 |
| 65 | + else: |
| 66 | + raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}") |
40 | 67 |
|
41 | 68 | def forward(self, x): |
42 | 69 | x = self.fn(self.linear1(x)) |
|
0 commit comments