Skip to content

Commit 4a3c123

Browse files
committed
Uses EfficientNetB0 as segmentation model encoder backbone
1 parent 1d0cf50 commit 4a3c123

File tree

2 files changed

+246
-19
lines changed

2 files changed

+246
-19
lines changed

robosat/efficientnet.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
"""EfficientNet architecture.
2+
3+
See:
4+
- https://arxiv.org/abs/1905.11946 - EfficientNet
5+
- https://arxiv.org/abs/1801.04381 - MobileNet V2
6+
- https://arxiv.org/abs/1905.02244 - MobileNet V3
7+
- https://arxiv.org/abs/1709.01507 - Squeeze-and-Excitation
8+
- https://arxiv.org/abs/1803.02579 - Concurrent spatial and channel squeeze-and-excitation
9+
- https://arxiv.org/abs/1812.01187 - Bag of Tricks for Image Classification with Convolutional Neural Networks
10+
11+
12+
Known issues:
13+
14+
- Not using swish activation function: unclear where, if, and how
15+
much it helps. Needs more experimentation. See also MobileNet V3.
16+
17+
- Not using squeeze and excitation blocks: I had significantly worse
18+
results with scse blocks, and cse blocks alone did not help, too.
19+
Needs more experimentation as it was done on small datasets only.
20+
21+
- Not using DropConnect: no efficient native implementation in PyTorch.
22+
Unclear if and how much it helps over Dropout.
23+
"""
24+
25+
import math
26+
import collections
27+
28+
import torch
29+
import torch.nn as nn
30+
31+
32+
EfficientNetParam = collections.namedtuple("EfficientNetParam", [
33+
"width", "depth", "resolution", "dropout"])
34+
35+
EfficientNetParams = {
36+
"B0": EfficientNetParam(1.0, 1.0, 224, 0.2),
37+
"B1": EfficientNetParam(1.0, 1.1, 240, 0.2),
38+
"B2": EfficientNetParam(1.1, 1.2, 260, 0.3),
39+
"B3": EfficientNetParam(1.2, 1.4, 300, 0.3),
40+
"B4": EfficientNetParam(1.4, 1.8, 380, 0.4),
41+
"B5": EfficientNetParam(1.6, 2.2, 456, 0.4),
42+
"B6": EfficientNetParam(1.8, 2.6, 528, 0.5),
43+
"B7": EfficientNetParam(2.0, 3.1, 600, 0.5)}
44+
45+
46+
def efficientnet0(pretrained=False, progress=False, num_classes=1000):
47+
return EfficientNet(param=EfficientNetParams["B0"], num_classes=num_classes)
48+
49+
def efficientnet1(pretrained=False, progress=False, num_classes=1000):
50+
return EfficientNet(param=EfficientNetParams["B1"], num_classes=num_classes)
51+
52+
def efficientnet2(pretrained=False, progress=False, num_classes=1000):
53+
return EfficientNet(param=EfficientNetParams["B2"], num_classes=num_classes)
54+
55+
def efficientnet3(pretrained=False, progress=False, num_classes=1000):
56+
return EfficientNet(param=EfficientNetParams["B3"], num_classes=num_classes)
57+
58+
def efficientnet4(pretrained=False, progress=False, num_classes=1000):
59+
return EfficientNet(param=EfficientNetParams["B4"], num_classes=num_classes)
60+
61+
def efficientnet5(pretrained=False, progress=False, num_classes=1000):
62+
return EfficientNet(param=EfficientNetParams["B5"], num_classes=num_classes)
63+
64+
def efficientnet6(pretrained=False, progress=False, num_classes=1000):
65+
return EfficientNet(param=EfficientNetParams["B6"], num_classes=num_classes)
66+
67+
def efficientnet7(pretrained=False, progress=False, num_classes=1000):
68+
return EfficientNet(param=EfficientNetParams["B7"], num_classes=num_classes)
69+
70+
71+
class EfficientNet(nn.Module):
72+
def __init__(self, param, num_classes=1000):
73+
super().__init__()
74+
75+
# For the exact scaling technique we follow the official implementation as the paper does not tell us
76+
# https://github.com/tensorflow/tpu/blob/01574500090fa9c011cb8418c61d442286720211/models/official/efficientnet/efficientnet_model.py#L101-L125
77+
78+
def scaled_depth(n):
79+
return int(math.ceil(n * param.depth))
80+
81+
# Snap number of channels to multiple of 8 for optimized implementations
82+
def scaled_width(n):
83+
n = n * param.width
84+
m = max(8, int(n + 8 / 2) // 8 * 8)
85+
86+
if m < 0.9 * n:
87+
m = m + 8
88+
89+
return int(m)
90+
91+
self.conv1 = nn.Conv2d(3, scaled_width(32), kernel_size=3, stride=2, padding=1, bias=False)
92+
self.bn1 = nn.BatchNorm2d(scaled_width(32))
93+
self.relu = nn.ReLU6(inplace=True)
94+
95+
self.layer1 = self._make_layer(n=scaled_depth(1), expansion=1, cin=scaled_width(32), cout=scaled_width(16), kernel_size=3, stride=1)
96+
self.layer2 = self._make_layer(n=scaled_depth(2), expansion=6, cin=scaled_width(16), cout=scaled_width(24), kernel_size=3, stride=2)
97+
self.layer3 = self._make_layer(n=scaled_depth(2), expansion=6, cin=scaled_width(24), cout=scaled_width(40), kernel_size=5, stride=2)
98+
self.layer4 = self._make_layer(n=scaled_depth(3), expansion=6, cin=scaled_width(40), cout=scaled_width(80), kernel_size=3, stride=2)
99+
self.layer5 = self._make_layer(n=scaled_depth(3), expansion=6, cin=scaled_width(80), cout=scaled_width(112), kernel_size=5, stride=1)
100+
self.layer6 = self._make_layer(n=scaled_depth(4), expansion=6, cin=scaled_width(112), cout=scaled_width(192), kernel_size=5, stride=2)
101+
self.layer7 = self._make_layer(n=scaled_depth(1), expansion=6, cin=scaled_width(192), cout=scaled_width(320), kernel_size=3, stride=1)
102+
103+
self.features = nn.Conv2d(scaled_width(320), scaled_width(1280), kernel_size=1, bias=False)
104+
105+
self.avgpool = nn.AdaptiveAvgPool2d(1)
106+
self.dropout = nn.Dropout(param.dropout, inplace=True)
107+
self.fc = nn.Linear(scaled_width(1280), num_classes)
108+
109+
for m in self.modules():
110+
if isinstance(m, nn.Conv2d):
111+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
112+
elif isinstance(m, nn.BatchNorm2d):
113+
nn.init.ones_(m.weight)
114+
nn.init.zeros_(m.bias)
115+
elif isinstance(m, nn.Linear):
116+
nn.init.normal_(m.weight, 0, 0.01)
117+
nn.init.zeros_(m.bias)
118+
119+
# Zero BatchNorm weight at end of res-blocks: identity by default
120+
# See https://arxiv.org/abs/1812.01187 Section 3.1
121+
for m in self.modules():
122+
if isinstance(m, Bottleneck):
123+
nn.init.zeros_(m.linear[1].weight)
124+
125+
126+
def _make_layer(self, n, expansion, cin, cout, kernel_size=3, stride=1):
127+
layers = []
128+
129+
for i in range(n):
130+
if i == 0:
131+
planes = cin
132+
expand = cin * expansion
133+
squeeze = cout
134+
stride = stride
135+
else:
136+
planes = cout
137+
expand = cout * expansion
138+
squeeze = cout
139+
stride = 1
140+
141+
layers += [Bottleneck(planes, expand, squeeze, kernel_size=kernel_size, stride=stride)]
142+
143+
return nn.Sequential(*layers)
144+
145+
146+
def forward(self, x):
147+
x = self.conv1(x)
148+
x = self.bn1(x)
149+
x = self.relu(x)
150+
151+
x = self.layer1(x)
152+
x = self.layer2(x)
153+
x = self.layer3(x)
154+
x = self.layer4(x)
155+
x = self.layer5(x)
156+
x = self.layer6(x)
157+
x = self.layer7(x)
158+
159+
x = self.features(x)
160+
161+
x = self.avgpool(x)
162+
x = x.reshape(x.size(0), -1)
163+
x = self.dropout(x)
164+
x = self.fc(x)
165+
166+
return x
167+
168+
169+
class Bottleneck(nn.Module):
170+
def __init__(self, planes, expand, squeeze, kernel_size, stride):
171+
super().__init__()
172+
173+
self.expand = nn.Identity() if planes == expand else nn.Sequential(
174+
nn.Conv2d(planes, expand, kernel_size=1, bias=False),
175+
nn.BatchNorm2d(expand),
176+
nn.ReLU6(inplace=True))
177+
178+
self.depthwise = nn.Sequential(
179+
nn.Conv2d(expand, expand, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=expand, bias=False),
180+
nn.BatchNorm2d(expand),
181+
nn.ReLU6(inplace=True))
182+
183+
self.linear = nn.Sequential(
184+
nn.Conv2d(expand, squeeze, kernel_size=1, bias=False),
185+
nn.BatchNorm2d(squeeze))
186+
187+
# Make all blocks skip-able via AvgPool + 1x1 Conv
188+
# See https://arxiv.org/abs/1812.01187 Figure 2 c
189+
190+
downsample = []
191+
192+
if stride != 1:
193+
downsample += [nn.AvgPool2d(kernel_size=stride, stride=stride)]
194+
195+
if planes != squeeze:
196+
downsample += [
197+
nn.Conv2d(planes, squeeze, kernel_size=1, stride=1, bias=False),
198+
nn.BatchNorm2d(squeeze)]
199+
200+
self.downsample = nn.Identity() if not downsample else nn.Sequential(*downsample)
201+
202+
203+
def forward(self, x):
204+
xx = self.expand(x)
205+
xx = self.depthwise(xx)
206+
xx = self.linear(xx)
207+
208+
x = self.downsample(x)
209+
xx.add_(x)
210+
211+
return xx

robosat/unet.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313
import torch.nn as nn
1414

15-
from torchvision.models import resnet50
15+
from robosat.efficientnet import efficientnet0
1616

1717

1818
class ConvRelu(nn.Module):
@@ -91,17 +91,17 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):
9191

9292
# Todo: make input channels configurable, not hard-coded to three channels for RGB
9393

94-
self.resnet = resnet50(pretrained=pretrained)
94+
self.net = efficientnet0(pretrained=pretrained)
9595

96-
# Access resnet directly in forward pass; do not store refs here due to
96+
# Access backbone directly in forward pass; do not store refs here due to
9797
# https://github.com/pytorch/pytorch/issues/8392
9898

99-
self.center = DecoderBlock(2048, num_filters * 8)
99+
self.center = DecoderBlock(1280, num_filters * 8)
100100

101-
self.dec0 = DecoderBlock(2048 + num_filters * 8, num_filters * 8)
102-
self.dec1 = DecoderBlock(1024 + num_filters * 8, num_filters * 8)
103-
self.dec2 = DecoderBlock(512 + num_filters * 8, num_filters * 2)
104-
self.dec3 = DecoderBlock(256 + num_filters * 2, num_filters * 2 * 2)
101+
self.dec0 = DecoderBlock(1280 + num_filters * 8, num_filters * 8)
102+
self.dec1 = DecoderBlock(112 + num_filters * 8, num_filters * 8)
103+
self.dec2 = DecoderBlock(40 + num_filters * 8, num_filters * 2)
104+
self.dec3 = DecoderBlock(24 + num_filters * 2, num_filters * 2 * 2)
105105
self.dec4 = DecoderBlock(num_filters * 2 * 2, num_filters)
106106
self.dec5 = ConvRelu(num_filters, num_filters)
107107

@@ -117,17 +117,33 @@ def forward(self, x):
117117
The networks output tensor.
118118
"""
119119
size = x.size()
120-
assert size[-1] % 32 == 0 and size[-2] % 32 == 0, "image resolution has to be divisible by 32 for resnet"
121-
122-
enc0 = self.resnet.conv1(x)
123-
enc0 = self.resnet.bn1(enc0)
124-
enc0 = self.resnet.relu(enc0)
125-
enc0 = self.resnet.maxpool(enc0)
126-
127-
enc1 = self.resnet.layer1(enc0)
128-
enc2 = self.resnet.layer2(enc1)
129-
enc3 = self.resnet.layer3(enc2)
130-
enc4 = self.resnet.layer4(enc3)
120+
assert size[-1] % 32 == 0 and size[-2] % 32 == 0, "image resolution has to be divisible by 32 for backbone"
121+
122+
# 1, 3, 512, 512
123+
enc0 = self.net.conv1(x)
124+
enc0 = self.net.bn1(enc0)
125+
enc0 = self.net.relu(enc0)
126+
# 1, 32, 256, 256
127+
enc0 = self.net.layer1(enc0)
128+
# 1, 16, 256, 256
129+
130+
enc1 = self.net.layer2(enc0)
131+
# 1, 24, 128, 128
132+
133+
enc2 = self.net.layer3(enc1)
134+
# 1, 40, 64, 64
135+
136+
enc3 = self.net.layer4(enc2)
137+
# 1, 80, 32, 32
138+
enc3 = self.net.layer5(enc3)
139+
# 1, 112, 32, 32
140+
141+
enc4 = self.net.layer6(enc3)
142+
# 1, 192, 16, 16
143+
enc4 = self.net.layer7(enc4)
144+
# 1, 320, 16, 16
145+
enc4 = self.net.features(enc4)
146+
# 1, 1280, 16, 16
131147

132148
center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))
133149

0 commit comments

Comments
 (0)