Skip to content

Commit d538ef0

Browse files
committed
取消初始化权重函数
1 parent 3add4e1 commit d538ef0

File tree

4 files changed

+5
-43
lines changed

4 files changed

+5
-43
lines changed

CIFAR10_code/nets/AlexNet.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
# 定义2012的AlexNet
99
class AlexNet(nn.Module):
10-
def __init__(self,num_classes=10, init_weights=True):
10+
def __init__(self,num_classes=10):
1111
super(AlexNet,self).__init__()
1212
# 五个卷积层 输入 32 * 32 * 3
1313
self.conv1 = nn.Sequential(
@@ -43,8 +43,7 @@ def __init__(self,num_classes=10, init_weights=True):
4343
nn.ReLU(),
4444
nn.Linear(84,num_classes)
4545
)
46-
if init_weights:
47-
self._initialize_weights()
46+
4847
def forward(self,x):
4948
x = self.conv1(x)
5049
x = self.conv2(x)
@@ -55,18 +54,6 @@ def forward(self,x):
5554
x = self.fc(x)
5655
return x
5756

58-
def _initialize_weights(self):
59-
for m in self.modules():
60-
if isinstance(m, nn.Conv2d):
61-
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
62-
if m.bias is not None:
63-
nn.init.constant_(m.bias, 0)
64-
elif isinstance(m, nn.BatchNorm2d):
65-
nn.init.constant_(m.weight, 1)
66-
nn.init.constant_(m.bias, 0)
67-
elif isinstance(m, nn.Linear):
68-
nn.init.normal_(m.weight, 0, 0.01)
69-
nn.init.constant_(m.bias, 0)
7057
def test():
7158
net = AlexNet()
7259
x = torch.randn(2,3,32,32)

CIFAR10_code/nets/LeNet5.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.nn as nn
66

77
class LeNet5(nn.Module):
8-
def __init__(self, num_classes = 10, init_weights=True):
8+
def __init__(self, num_classes = 10):
99
super(LeNet5,self).__init__()
1010
self.conv1 = nn.Sequential(
1111
# 输入 32x32x3 -> 28x28x6 (32-5)/1 + 1=28
@@ -30,8 +30,7 @@ def __init__(self, num_classes = 10, init_weights=True):
3030
nn.ReLU(),
3131
nn.Linear(84,num_classes)
3232
)
33-
if init_weights:
34-
self._initialize_weights()
33+
3534
def forward(self,x):
3635
x = self.conv1(x)
3736
x = self.conv2(x)
@@ -40,18 +39,7 @@ def forward(self,x):
4039
x = self.fc(x)
4140
return x
4241

43-
def _initialize_weights(self):
44-
for m in self.modules():
45-
if isinstance(m, nn.Conv2d):
46-
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
47-
if m.bias is not None:
48-
nn.init.constant_(m.bias, 0)
49-
elif isinstance(m, nn.BatchNorm2d):
50-
nn.init.constant_(m.weight, 1)
51-
nn.init.constant_(m.bias, 0)
52-
elif isinstance(m, nn.Linear):
53-
nn.init.normal_(m.weight, 0, 0.01)
54-
nn.init.constant_(m.bias, 0)
42+
5543

5644
def test():
5745
net = LeNet5()

CIFAR10_code/nets/MobileNetv2.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,6 @@ def _make_layers(self, in_channels):
6868
in_channels = out_channels
6969
return nn.Sequential(*layers)
7070

71-
def init_weight(self):
72-
for w in self.modules():
73-
if isinstance(w, nn.Conv2d):
74-
nn.init.kaiming_normal_(w.weight, mode='fan_out')
75-
if w.bias is not None:
76-
nn.init.zeros_(w.bias)
77-
elif isinstance(w, nn.BatchNorm2d):
78-
nn.init.ones_(w.weight)
79-
nn.init.zeros_(w.bias)
80-
elif isinstance(w, nn.Linear):
81-
nn.init.normal_(w.weight, 0, 0.01)
82-
nn.init.zeros_(w.bias)
8371

8472
def forward(self, x):
8573
out = self.relu6(self.bn1(self.conv1(x)))

CIFAR10_code/train.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
'''Train CIFAR10 with PyTorch.'''
2-
import imp
32
import torch
43
import torch.nn as nn
54
import torch.optim as optim

0 commit comments

Comments
 (0)