Skip to content

Commit 03c9da4

Browse files
committed
增加权重初始化函数
1 parent 7f88a15 commit 03c9da4

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

CIFAR10_code/nets/MobileNetv1.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,29 @@ def forward(self,x):
5656
x = x.view(x.size()[0],-1)
5757
x = self.linear(x)
5858
return x
59-
59+
60+
def init_weight(self):
61+
for w in self.modules():
62+
if isinstance(w, nn.Conv2d):
63+
nn.init.kaiming_normal_(w.weight, mode='fan_out')
64+
if w.bias is not None:
65+
nn.init.zeros_(w.bias)
66+
elif isinstance(w, nn.BatchNorm2d):
67+
nn.init.ones_(w.weight)
68+
nn.init.zeros_(w.bias)
69+
elif isinstance(w, nn.Linear):
70+
nn.init.normal_(w.weight, 0, 0.01)
71+
nn.init.zeros_(w.bias)
72+
73+
6074
def test():
6175
net = MobileNet()
6276
x = torch.randn(2,3,32,32)
6377
y = net(x)
6478
print(y.size())
65-
from torchsummary import summary
79+
from torchinfo import summary
6680
device = 'cuda' if torch.cuda.is_available() else 'cpu'
6781
net = net.to(device)
68-
summary(net,(3,32,32))
82+
summary(net,(32,3,32,32))
6983

70-
# test()
84+
test()

0 commit comments

Comments
 (0)