|
| 1 | +"""EfficientNet architecture. |
| 2 | +
|
| 3 | +See: |
| 4 | +- https://arxiv.org/abs/1905.11946 - EfficientNet |
| 5 | +- https://arxiv.org/abs/1801.04381 - MobileNet V2 |
| 6 | +- https://arxiv.org/abs/1905.02244 - MobileNet V3 |
| 7 | +- https://arxiv.org/abs/1709.01507 - Squeeze-and-Excitation |
| 8 | +- https://arxiv.org/abs/1803.02579 - Concurrent spatial and channel squeeze-and-excitation |
| 9 | +- https://arxiv.org/abs/1812.01187 - Bag of Tricks for Image Classification with Convolutional Neural Networks |
| 10 | +
|
| 11 | +
|
| 12 | +Known issues: |
| 13 | +
|
| 14 | +- Not using swish activation function: unclear where, if, and how |
| 15 | + much it helps. Needs more experimentation. See also MobileNet V3. |
| 16 | +
|
| 17 | +- Not using squeeze and excitation blocks: I had significantly worse |
| 18 | + results with scse blocks, and cse blocks alone did not help, too. |
| 19 | + Needs more experimentation as it was done on small datasets only. |
| 20 | +
|
| 21 | +- Not using DropConnect: no efficient native implementation in PyTorch. |
| 22 | + Unclear if and how much it helps over Dropout. |
| 23 | +""" |
| 24 | + |
| 25 | +import math |
| 26 | +import collections |
| 27 | + |
| 28 | +import torch |
| 29 | +import torch.nn as nn |
| 30 | + |
| 31 | + |
| 32 | +EfficientNetParam = collections.namedtuple("EfficientNetParam", [ |
| 33 | + "width", "depth", "resolution", "dropout"]) |
| 34 | + |
| 35 | +EfficientNetParams = { |
| 36 | + "B0": EfficientNetParam(1.0, 1.0, 224, 0.2), |
| 37 | + "B1": EfficientNetParam(1.0, 1.1, 240, 0.2), |
| 38 | + "B2": EfficientNetParam(1.1, 1.2, 260, 0.3), |
| 39 | + "B3": EfficientNetParam(1.2, 1.4, 300, 0.3), |
| 40 | + "B4": EfficientNetParam(1.4, 1.8, 380, 0.4), |
| 41 | + "B5": EfficientNetParam(1.6, 2.2, 456, 0.4), |
| 42 | + "B6": EfficientNetParam(1.8, 2.6, 528, 0.5), |
| 43 | + "B7": EfficientNetParam(2.0, 3.1, 600, 0.5)} |
| 44 | + |
| 45 | + |
| 46 | +def efficientnet0(pretrained=False, progress=False, num_classes=1000): |
| 47 | + return EfficientNet(param=EfficientNetParams["B0"], num_classes=num_classes) |
| 48 | + |
| 49 | +def efficientnet1(pretrained=False, progress=False, num_classes=1000): |
| 50 | + return EfficientNet(param=EfficientNetParams["B1"], num_classes=num_classes) |
| 51 | + |
| 52 | +def efficientnet2(pretrained=False, progress=False, num_classes=1000): |
| 53 | + return EfficientNet(param=EfficientNetParams["B2"], num_classes=num_classes) |
| 54 | + |
| 55 | +def efficientnet3(pretrained=False, progress=False, num_classes=1000): |
| 56 | + return EfficientNet(param=EfficientNetParams["B3"], num_classes=num_classes) |
| 57 | + |
| 58 | +def efficientnet4(pretrained=False, progress=False, num_classes=1000): |
| 59 | + return EfficientNet(param=EfficientNetParams["B4"], num_classes=num_classes) |
| 60 | + |
| 61 | +def efficientnet5(pretrained=False, progress=False, num_classes=1000): |
| 62 | + return EfficientNet(param=EfficientNetParams["B5"], num_classes=num_classes) |
| 63 | + |
| 64 | +def efficientnet6(pretrained=False, progress=False, num_classes=1000): |
| 65 | + return EfficientNet(param=EfficientNetParams["B6"], num_classes=num_classes) |
| 66 | + |
| 67 | +def efficientnet7(pretrained=False, progress=False, num_classes=1000): |
| 68 | + return EfficientNet(param=EfficientNetParams["B7"], num_classes=num_classes) |
| 69 | + |
| 70 | + |
| 71 | +class EfficientNet(nn.Module): |
| 72 | + def __init__(self, param, num_classes=1000): |
| 73 | + super().__init__() |
| 74 | + |
| 75 | + # For the exact scaling technique we follow the official implementation as the paper does not tell us |
| 76 | + # https://github.com/tensorflow/tpu/blob/01574500090fa9c011cb8418c61d442286720211/models/official/efficientnet/efficientnet_model.py#L101-L125 |
| 77 | + |
| 78 | + def scaled_depth(n): |
| 79 | + return int(math.ceil(n * param.depth)) |
| 80 | + |
| 81 | + # Snap number of channels to multiple of 8 for optimized implementations |
| 82 | + def scaled_width(n): |
| 83 | + n = n * param.width |
| 84 | + m = max(8, int(n + 8 / 2) // 8 * 8) |
| 85 | + |
| 86 | + if m < 0.9 * n: |
| 87 | + m = m + 8 |
| 88 | + |
| 89 | + return int(m) |
| 90 | + |
| 91 | + self.conv1 = nn.Conv2d(3, scaled_width(32), kernel_size=3, stride=2, padding=1, bias=False) |
| 92 | + self.bn1 = nn.BatchNorm2d(scaled_width(32)) |
| 93 | + self.relu = nn.ReLU6(inplace=True) |
| 94 | + |
| 95 | + self.layer1 = self._make_layer(n=scaled_depth(1), expansion=1, cin=scaled_width(32), cout=scaled_width(16), kernel_size=3, stride=1) |
| 96 | + self.layer2 = self._make_layer(n=scaled_depth(2), expansion=6, cin=scaled_width(16), cout=scaled_width(24), kernel_size=3, stride=2) |
| 97 | + self.layer3 = self._make_layer(n=scaled_depth(2), expansion=6, cin=scaled_width(24), cout=scaled_width(40), kernel_size=5, stride=2) |
| 98 | + self.layer4 = self._make_layer(n=scaled_depth(3), expansion=6, cin=scaled_width(40), cout=scaled_width(80), kernel_size=3, stride=2) |
| 99 | + self.layer5 = self._make_layer(n=scaled_depth(3), expansion=6, cin=scaled_width(80), cout=scaled_width(112), kernel_size=5, stride=1) |
| 100 | + self.layer6 = self._make_layer(n=scaled_depth(4), expansion=6, cin=scaled_width(112), cout=scaled_width(192), kernel_size=5, stride=2) |
| 101 | + self.layer7 = self._make_layer(n=scaled_depth(1), expansion=6, cin=scaled_width(192), cout=scaled_width(320), kernel_size=3, stride=1) |
| 102 | + |
| 103 | + self.features = nn.Conv2d(scaled_width(320), scaled_width(1280), kernel_size=1, bias=False) |
| 104 | + |
| 105 | + self.avgpool = nn.AdaptiveAvgPool2d(1) |
| 106 | + self.dropout = nn.Dropout(param.dropout, inplace=True) |
| 107 | + self.fc = nn.Linear(scaled_width(1280), num_classes) |
| 108 | + |
| 109 | + for m in self.modules(): |
| 110 | + if isinstance(m, nn.Conv2d): |
| 111 | + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") |
| 112 | + elif isinstance(m, nn.BatchNorm2d): |
| 113 | + nn.init.ones_(m.weight) |
| 114 | + nn.init.zeros_(m.bias) |
| 115 | + elif isinstance(m, nn.Linear): |
| 116 | + nn.init.normal_(m.weight, 0, 0.01) |
| 117 | + nn.init.zeros_(m.bias) |
| 118 | + |
| 119 | + # Zero BatchNorm weight at end of res-blocks: identity by default |
| 120 | + # See https://arxiv.org/abs/1812.01187 Section 3.1 |
| 121 | + for m in self.modules(): |
| 122 | + if isinstance(m, Bottleneck): |
| 123 | + nn.init.zeros_(m.linear[1].weight) |
| 124 | + |
| 125 | + |
| 126 | + def _make_layer(self, n, expansion, cin, cout, kernel_size=3, stride=1): |
| 127 | + layers = [] |
| 128 | + |
| 129 | + for i in range(n): |
| 130 | + if i == 0: |
| 131 | + planes = cin |
| 132 | + expand = cin * expansion |
| 133 | + squeeze = cout |
| 134 | + stride = stride |
| 135 | + else: |
| 136 | + planes = cout |
| 137 | + expand = cout * expansion |
| 138 | + squeeze = cout |
| 139 | + stride = 1 |
| 140 | + |
| 141 | + layers += [Bottleneck(planes, expand, squeeze, kernel_size=kernel_size, stride=stride)] |
| 142 | + |
| 143 | + return nn.Sequential(*layers) |
| 144 | + |
| 145 | + |
| 146 | + def forward(self, x): |
| 147 | + x = self.conv1(x) |
| 148 | + x = self.bn1(x) |
| 149 | + x = self.relu(x) |
| 150 | + |
| 151 | + x = self.layer1(x) |
| 152 | + x = self.layer2(x) |
| 153 | + x = self.layer3(x) |
| 154 | + x = self.layer4(x) |
| 155 | + x = self.layer5(x) |
| 156 | + x = self.layer6(x) |
| 157 | + x = self.layer7(x) |
| 158 | + |
| 159 | + x = self.features(x) |
| 160 | + |
| 161 | + x = self.avgpool(x) |
| 162 | + x = x.reshape(x.size(0), -1) |
| 163 | + x = self.dropout(x) |
| 164 | + x = self.fc(x) |
| 165 | + |
| 166 | + return x |
| 167 | + |
| 168 | + |
| 169 | +class Bottleneck(nn.Module): |
| 170 | + def __init__(self, planes, expand, squeeze, kernel_size, stride): |
| 171 | + super().__init__() |
| 172 | + |
| 173 | + self.expand = nn.Identity() if planes == expand else nn.Sequential( |
| 174 | + nn.Conv2d(planes, expand, kernel_size=1, bias=False), |
| 175 | + nn.BatchNorm2d(expand), |
| 176 | + nn.ReLU6(inplace=True)) |
| 177 | + |
| 178 | + self.depthwise = nn.Sequential( |
| 179 | + nn.Conv2d(expand, expand, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=expand, bias=False), |
| 180 | + nn.BatchNorm2d(expand), |
| 181 | + nn.ReLU6(inplace=True)) |
| 182 | + |
| 183 | + self.linear = nn.Sequential( |
| 184 | + nn.Conv2d(expand, squeeze, kernel_size=1, bias=False), |
| 185 | + nn.BatchNorm2d(squeeze)) |
| 186 | + |
| 187 | + # Make all blocks skip-able via AvgPool + 1x1 Conv |
| 188 | + # See https://arxiv.org/abs/1812.01187 Figure 2 c |
| 189 | + |
| 190 | + downsample = [] |
| 191 | + |
| 192 | + if stride != 1: |
| 193 | + downsample += [nn.AvgPool2d(kernel_size=stride, stride=stride)] |
| 194 | + |
| 195 | + if planes != squeeze: |
| 196 | + downsample += [ |
| 197 | + nn.Conv2d(planes, squeeze, kernel_size=1, stride=1, bias=False), |
| 198 | + nn.BatchNorm2d(squeeze)] |
| 199 | + |
| 200 | + self.downsample = nn.Identity() if not downsample else nn.Sequential(*downsample) |
| 201 | + |
| 202 | + |
| 203 | + def forward(self, x): |
| 204 | + xx = self.expand(x) |
| 205 | + xx = self.depthwise(xx) |
| 206 | + xx = self.linear(xx) |
| 207 | + |
| 208 | + x = self.downsample(x) |
| 209 | + xx.add_(x) |
| 210 | + |
| 211 | + return xx |
0 commit comments