Skip to content

Commit 65b917d

Browse files
committed
构建MobileNetv2
1 parent 03c9da4 commit 65b917d

File tree

1 file changed

+104
-0
lines changed

1 file changed

+104
-0
lines changed

CIFAR10_code/nets/MobileNetv2.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
'''
2+
MobileNetV2 in PyTorch.
3+
4+
See the paper "Inverted Residuals and Linear Bottlenecks:
5+
Mobile Networks for Classification, Detection and Segmentation" for more details.
6+
'''
7+
import torch
8+
import torch.nn as nn
9+
10+
11+
class Block(nn.Module):
12+
# 使用了inverted residuals
13+
'''expand + depthwise + pointwise'''
14+
def __init__(self, in_channels, out_channels, expansion, stride):
15+
super(Block, self).__init__()
16+
self.stride = stride
17+
channels = expansion * in_channels # 倒残差结构先升维 再降维
18+
self.conv1 = nn.Conv2d(in_channels, channels, kernel_size = 1, stride=1, padding=0, bias=False)
19+
self.bn1 = nn.BatchNorm2d(channels)
20+
self.conv2 = nn.Conv2d(channels,channels,kernel_size=3,stride=stride,padding=1, groups=channels, bias=False)
21+
self.bn2 = nn.BatchNorm2d(channels)
22+
self.conv3 = nn.Conv2d(channels, out_channels, kernel_size=1,stride=1, padding = 0, bias=False)
23+
self.bn3 = nn.BatchNorm2d(out_channels)
24+
25+
self.shortcut = nn.Sequential()
26+
if stride == 1 and in_channels != out_channels:
27+
self.shortcut = nn.Sequential(
28+
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
29+
nn.BatchNorm2d(out_channels)
30+
)
31+
self.relu6 = nn.ReLU6()
32+
def forward(self, x):
33+
out = self.relu6(self.bn1(self.conv1(x)))
34+
out = self.relu6(self.bn2(self.conv2(out)))
35+
out = self.bn3(self.conv3(out))
36+
out = out + self.shortcut(x) if self.stride == 1 else out
37+
38+
return out
39+
40+
class MobileNetV2(nn.Module):
41+
# (expansion, out_channels, num_blocks, stride)
42+
cfg = [(1, 16, 1, 1),
43+
(6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10
44+
(6, 32, 3, 2),
45+
(6, 64, 4, 2),
46+
(6, 96, 3, 1),
47+
(6, 160, 3, 2),
48+
(6, 320, 1, 1)]
49+
50+
def __init__(self, num_classes=10):
51+
super(MobileNetV2, self).__init__()
52+
# NOTE: change conv1 stride 2 -> 1 for CIFAR10
53+
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
54+
self.bn1 = nn.BatchNorm2d(32)
55+
self.layers = self._make_layers(in_channels=32)
56+
self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
57+
self.bn2 = nn.BatchNorm2d(1280)
58+
self.avgpool = nn.AdaptiveAvgPool2d(1)
59+
self.linear = nn.Linear(1280, num_classes)
60+
self.relu6 = nn.ReLU6()
61+
62+
def _make_layers(self, in_channels):
63+
layers = []
64+
for expansion, out_channels, num_block, stride in self.cfg:
65+
strides = [stride] + [1]*(num_block-1)
66+
for stride in strides:
67+
layers.append(Block(in_channels, out_channels, expansion, stride))
68+
in_channels = out_channels
69+
return nn.Sequential(*layers)
70+
71+
def init_weight(self):
72+
for w in self.modules():
73+
if isinstance(w, nn.Conv2d):
74+
nn.init.kaiming_normal_(w.weight, mode='fan_out')
75+
if w.bias is not None:
76+
nn.init.zeros_(w.bias)
77+
elif isinstance(w, nn.BatchNorm2d):
78+
nn.init.ones_(w.weight)
79+
nn.init.zeros_(w.bias)
80+
elif isinstance(w, nn.Linear):
81+
nn.init.normal_(w.weight, 0, 0.01)
82+
nn.init.zeros_(w.bias)
83+
84+
def forward(self, x):
85+
out = self.relu6(self.bn1(self.conv1(x)))
86+
out = self.layers(out)
87+
out = self.relu6(self.bn2(self.conv2(out)))
88+
out = self.avgpool(out)
89+
out = out.view(out.size(0),-1)
90+
out = self.linear(out)
91+
return out
92+
93+
94+
def test():
95+
net = MobileNetV2()
96+
x = torch.randn(2,3,32,32)
97+
y = net(x)
98+
print(y.size())
99+
from torchinfo import summary
100+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
101+
net = net.to(device)
102+
summary(net,(32,3,32,32))
103+
104+
test()

0 commit comments

Comments
 (0)