Commit 336c7c7

Add weight initialization
1 parent 236a0f2 commit 336c7c7

6 files changed: +83 -17 lines


CIFAR10_code/nets/AlexNet.py

Lines changed: 16 additions & 3 deletions
@@ -7,7 +7,7 @@

 # define the 2012 AlexNet
 class AlexNet(nn.Module):
-    def __init__(self,num_classes=10):
+    def __init__(self,num_classes=10, init_weights=True):
         super(AlexNet,self).__init__()
         # five conv layers, input 32 * 32 * 3
         self.conv1 = nn.Sequential(
@@ -43,7 +43,8 @@ def __init__(self,num_classes=10):
             nn.ReLU(),
             nn.Linear(84,num_classes)
         )
-
+        if init_weights:
+            self._initialize_weights()
     def forward(self,x):
         x = self.conv1(x)
         x = self.conv2(x)
@@ -53,7 +54,19 @@ def forward(self,x):
         x = x.view(x.size()[0],-1)
         x = self.fc(x)
         return x
-
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
 def test():
     net = AlexNet()
     x = torch.randn(2,3,32,32)
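
A quick check of what mode='fan_out', nonlinearity='relu' means for these conv layers: kaiming_normal_ draws from N(0, 2 / fan_out), where fan_out = out_channels * kernel_h * kernel_w for a Conv2d. A minimal standalone sketch (illustrative, not part of the commit):

import math
import torch.nn as nn

conv = nn.Conv2d(3, 96, kernel_size=3)
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

# expected std = sqrt(2 / fan_out) with fan_out = 96 * 3 * 3
expected_std = math.sqrt(2.0 / (96 * 3 * 3))
print(expected_std, conv.weight.std().item())  # the two values should roughly agree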

CIFAR10_code/nets/DenseNet.py

Lines changed: 12 additions & 5 deletions
@@ -72,7 +72,7 @@ class DenseNet(nn.Module):
     ->(384, 1, 1) -> [Linear] -> (10)

     """
-    def __init__(self, num_blocks, growth_rate=12, reduction=0.5, num_classes=10):
+    def __init__(self, num_blocks, growth_rate=12, reduction=0.5, num_classes=10, init_weights=True):
         super(DenseNet, self).__init__()
         self.growth_rate = growth_rate
         self.reduction = reduction
@@ -91,7 +91,8 @@ def __init__(self, num_blocks, growth_rate=12, reduction=0.5, num_classes=10):
         )
         self.classifier = nn.Linear(num_channels, num_classes)

-        self._initialize_weight()
+        if init_weights:
+            self._initialize_weights()

     def _make_dense_layer(self, in_channels, nblock, transition=True):
         layers = []
@@ -104,12 +105,18 @@ def _make_dense_layer(self, in_channels, nblock, transition=True):
             layers += [Transition(in_channels, out_channels)]
         return nn.Sequential(*layers), out_channels

-    def _initialize_weight(self):
+    def _initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight.data)
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                 if m.bias is not None:
-                    m.bias.data.zero_()
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)

     def forward(self, x):
         out = self.features(x)
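
Besides renaming _initialize_weight to _initialize_weights, this hunk changes the conv init in two ways: it acts on m.weight instead of m.weight.data (the nn.init.*_ functions are already in-place and run under no_grad, so going through .data is unnecessary), and it passes mode='fan_out' instead of the default fan_in, scaling the variance by outgoing rather than incoming connections. A small comparison sketch (illustrative, not part of the commit):

import torch.nn as nn

a = nn.Conv2d(24, 12, kernel_size=3)
b = nn.Conv2d(24, 12, kernel_size=3)
nn.init.kaiming_normal_(a.weight)                                        # fan_in  = 24 * 3 * 3 = 216
nn.init.kaiming_normal_(b.weight, mode='fan_out', nonlinearity='relu')   # fan_out = 12 * 3 * 3 = 108
print(a.weight.std().item(), b.weight.std().item())  # b's std is about sqrt(2) times a's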

CIFAR10_code/nets/LeNet-5.py

Lines changed: 15 additions & 2 deletions
@@ -5,7 +5,7 @@
 import torch.nn as nn

 class LeNet5(nn.Module):
-    def __init__(self, num_classes = 10):
+    def __init__(self, num_classes = 10, init_weights=True):
         super(LeNet5,self).__init__()
         self.conv1 = nn.Sequential(
             # input 32x32x3 -> 28x28x6, (32-5)/1 + 1 = 28
@@ -30,7 +30,8 @@ def __init__(self, num_classes = 10):
             nn.ReLU(),
             nn.Linear(84,num_classes)
         )
-
+        if init_weights:
+            self._initialize_weights()
     def forward(self,x):
         x = self.conv1(x)
         x = self.conv2(x)
@@ -39,6 +40,18 @@ def forward(self,x):
         x = self.fc(x)
         return x

+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)

 def test():
     net = LeNet5()
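
The init_weights=True default keeps existing call sites like test() working unchanged; passing False makes sense when the random init would be overwritten immediately anyway, e.g. when restoring a checkpoint. A hedged usage sketch (the checkpoint filename is hypothetical):

import torch

net = LeNet5(init_weights=False)            # skip the random re-init
state = torch.load('lenet5_cifar10.pth')    # hypothetical checkpoint file
net.load_state_dict(state)
net.eval()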

CIFAR10_code/nets/MobileNetv1.py

Lines changed: 5 additions & 2 deletions
@@ -29,7 +29,7 @@ class MobileNet(nn.Module):
     # (128,2) means conv channel=128, conv stride=2, by default conv stride=1
     cfg = [64,(128,2),128,(256,2),256,(512,2),512,512,512,512,512,(1024,2),1024]

-    def __init__(self, num_classes=10,alpha=1.0,beta=1.0):
+    def __init__(self, num_classes=10,alpha=1.0,beta=1.0,init_weights=True):
         super(MobileNet,self).__init__()
         self.conv1 = nn.Sequential(
             nn.Conv2d(3,32,kernel_size=3,stride=1,bias=False),
@@ -49,6 +49,9 @@ def _make_layers(self, in_channels):
             in_channels = out_channels
         return nn.Sequential(*layers)

+        if init_weights:
+            self._initialize_weights()
+
     def forward(self,x):
         x = self.conv1(x)
         x = self.layers(x)
@@ -57,7 +60,7 @@ def forward(self,x):
         x = self.linear(x)
         return x

-    def init_weight(self):
+    def _initialize_weights(self):
         for w in self.modules():
             if isinstance(w, nn.Conv2d):
                 nn.init.kaiming_normal_(w.weight, mode='fan_out')
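
Note that in this file the new if init_weights: block is inserted into _make_layers after its return statement, so it is unreachable (and init_weights is only a parameter of __init__ anyway) — the renamed _initialize_weights is still never called. The presumable intent is to invoke it at the end of __init__; a minimal sketch of that pattern with a hypothetical stand-in module:

import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical stand-in, not MobileNet itself
    def __init__(self, init_weights=True):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(3, 32, kernel_size=3, bias=False)
        if init_weights:                # init_weights is in scope here,
            self._initialize_weights()  # and the call runs once construction is done

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')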

CIFAR10_code/nets/ResNet.py

Lines changed: 18 additions & 2 deletions
@@ -79,7 +79,7 @@ class ResNet(nn.Module):
     -> (16, 16, 128) -> [Res3] -> (8, 8, 256) ->[Res4] -> (4, 4, 512) -> [AvgPool]
     -> (1, 1, 512) -> [Reshape] -> (512) -> [Linear] -> (10)
     """
-    def __init__(self, block, num_blocks, num_classes=10, verbose = False):
+    def __init__(self, block, num_blocks, num_classes=10, verbose = False, init_weights=True):
         super(ResNet, self).__init__()
         self.verbose = verbose
         self.in_channels = 64
@@ -97,7 +97,11 @@ def __init__(self, block, num_blocks, num_classes=10, verbose = False):
         # so a 4 x 4 average pooling is used here
         self.avg_pool = nn.AvgPool2d(kernel_size=4)
         self.classifer = nn.Linear(512 * block.expansion, num_classes)
-
+
+        if init_weights:
+            self._initialize_weights()
+
+
     def _make_layer(self, block, out_channels, num_blocks, stride):
         # the first block needs to downsample
         strides = [stride] + [1] * (num_blocks - 1)
@@ -130,6 +134,18 @@ def forward(self, x):
         out = self.classifer(out)
         return out

+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
 def ResNet18(verbose=False):
     return ResNet(BasicBlock, [2,2,2,2],verbose=verbose)
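
One refinement this generic loop leaves out, common in ResNet recipes (it is what torchvision's zero_init_residual flag does): setting the last BatchNorm weight of each residual block to 0 so every block starts out as an identity mapping. A hedged sketch, assuming this repo's BasicBlock names its second BatchNorm bn2:

import torch.nn as nn

def zero_init_residual(model, block_type, last_bn_name='bn2'):
    # hypothetical helper: zero the final BatchNorm gain of each residual
    # block so the block initially contributes nothing to the sum
    for m in model.modules():
        if isinstance(m, block_type):
            nn.init.constant_(getattr(m, last_bn_name).weight, 0)

This would be called after the generic init, e.g. zero_init_residual(net, BasicBlock).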

CIFAR10_code/nets/VGG.py

Lines changed: 17 additions & 3 deletions
@@ -13,7 +13,7 @@
 }

 class VGG(nn.Module):
-    def __init__(self, vggname = 'VGG16',num_classes=10):
+    def __init__(self, vggname = 'VGG16',num_classes=10, init_weights=True):
         super(VGG,self).__init__()
         self.features = self._make_layers(cfg[vggname])
         self.classifier = nn.Linear(512,num_classes)
@@ -31,13 +31,27 @@ def _make_layers(self,cfg):
             in_channels = x
         layers += [nn.AvgPool2d(kernel_size=1,stride=1)]
         return nn.Sequential(*layers)
-
+        if init_weights:
+            self._initialize_weight()
     def forward(self,x):
         x = self.features(x)
         x = x.view(x.size()[0],-1)
         x = self.classifier(x)
         return x
-
+    # initialize the parameters
+    def _initialize_weight(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                # xavier is used in VGG's paper
+                nn.init.xavier_normal_(m.weight.data)
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                m.weight.data.normal_(0, 0.01)
+                m.bias.data.zero_()
 def test():
     net = VGG('VGG19')
     x = torch.randn(2,3,32,32)
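
As in MobileNetv1.py, the if init_weights: call here lands after the return in _make_layers, so _initialize_weight() is never actually reached; it would need to move to the end of __init__, after self.classifier is assigned. The Xavier choice itself behaves differently from the Kaiming init used in the other files: xavier_normal_ draws from N(0, 2 / (fan_in + fan_out)) with no ReLU gain. A quick standalone check (illustrative, not part of the commit):

import math
import torch.nn as nn

lin = nn.Linear(512, 10)
nn.init.xavier_normal_(lin.weight)
expected_std = math.sqrt(2.0 / (512 + 10))
print(expected_std, lin.weight.std().item())  # should roughly agree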
