'''
MobileNetV2 in PyTorch.

See the paper "Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation" for more details.
'''
import torch
import torch.nn as nn


class Block(nn.Module):
    '''Inverted residual block: expand + depthwise + pointwise.'''
    def __init__(self, in_channels, out_channels, expansion, stride):
        super(Block, self).__init__()
        self.stride = stride
        # Inverted residual: expand to a wide representation first, then project back down.
        channels = expansion * in_channels
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, groups=channels, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv3 = nn.Conv2d(channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        self.relu6 = nn.ReLU6()  # bounded activation used throughout MobileNetV2

    def forward(self, x):
        out = self.relu6(self.bn1(self.conv1(x)))
        out = self.relu6(self.bn2(self.conv2(out)))
        # Linear bottleneck: no activation after the projection conv.
        out = self.bn3(self.conv3(out))
        # Residual addition only when the spatial resolution is unchanged.
        out = out + self.shortcut(x) if self.stride == 1 else out
        return out

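
# Shape sketch (illustrative only; the values below are assumed, not from the
# original code): with stride 1 and mismatched channel counts, the 1x1-conv
# shortcut makes the residual addition line up.
#   blk = Block(in_channels=16, out_channels=24, expansion=6, stride=1)
#   blk(torch.randn(1, 16, 32, 32)).shape  # -> torch.Size([1, 24, 32, 32])
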
class MobileNetV2(nn.Module):
    # (expansion, out_channels, num_blocks, stride)
    cfg = [(1,  16, 1, 1),
           (6,  24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10
           (6,  32, 3, 2),
           (6,  64, 4, 2),
           (6,  96, 3, 1),
           (6, 160, 3, 2),
           (6, 320, 1, 1)]
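    # These 7 stages unroll into 1+2+3+4+3+3+1 = 17 bottleneck Blocks;
    # within a stage, only the first Block applies the listed stride.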

    def __init__(self, num_classes=10):
        super(MobileNetV2, self).__init__()
        # NOTE: change conv1 stride 2 -> 1 for CIFAR10
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_channels=32)
        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(1280, num_classes)
        self.relu6 = nn.ReLU6()
        self.init_weight()  # apply the initialization defined below

    def _make_layers(self, in_channels):
        layers = []
        for expansion, out_channels, num_blocks, stride in self.cfg:
            # Only the first Block of a stage may downsample; the rest use stride 1,
            # e.g. (6, 32, 3, 2) yields strides [2, 1, 1].
            strides = [stride] + [1] * (num_blocks - 1)
            for s in strides:
                layers.append(Block(in_channels, out_channels, expansion, s))
                in_channels = out_channels
        return nn.Sequential(*layers)

    def init_weight(self):
        # He initialization for convs, constants for BN, small normal for the classifier.
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.BatchNorm2d):
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)
            elif isinstance(w, nn.Linear):
                nn.init.normal_(w.weight, 0, 0.01)
                nn.init.zeros_(w.bias)

    def forward(self, x):
        out = self.relu6(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = self.relu6(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


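# Feature-map sizes for a 32x32 CIFAR10 input (derived from cfg above):
# conv1 keeps 32x32; the stage strides 1,1,2,2,1,2,1 give
# 32 -> 32 -> 32 -> 16 -> 8 -> 8 -> 4 -> 4, so conv2 sees 320x4x4 and the
# adaptive average pool reduces it to 1280x1x1 before the classifier.
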
def test():
    net = MobileNetV2()
    x = torch.randn(2, 3, 32, 32)
    y = net(x)
    print(y.size())  # torch.Size([2, 10])
    from torchinfo import summary
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = net.to(device)
    summary(net, (32, 3, 32, 32))


if __name__ == '__main__':
    test()