77
88# 定义2012的AlexNet
99class AlexNet (nn .Module ):
10- def __init__ (self ,num_classes = 10 , init_weights = True ):
10+ def __init__ (self ,num_classes = 10 ):
1111 super (AlexNet ,self ).__init__ ()
1212 # 五个卷积层 输入 32 * 32 * 3
1313 self .conv1 = nn .Sequential (
@@ -43,8 +43,7 @@ def __init__(self,num_classes=10, init_weights=True):
4343 nn .ReLU (),
4444 nn .Linear (84 ,num_classes )
4545 )
46- if init_weights :
47- self ._initialize_weights ()
46+
4847 def forward (self ,x ):
4948 x = self .conv1 (x )
5049 x = self .conv2 (x )
@@ -55,18 +54,6 @@ def forward(self,x):
5554 x = self .fc (x )
5655 return x
5756
58- def _initialize_weights (self ):
59- for m in self .modules ():
60- if isinstance (m , nn .Conv2d ):
61- nn .init .kaiming_normal_ (m .weight , mode = 'fan_out' , nonlinearity = 'relu' )
62- if m .bias is not None :
63- nn .init .constant_ (m .bias , 0 )
64- elif isinstance (m , nn .BatchNorm2d ):
65- nn .init .constant_ (m .weight , 1 )
66- nn .init .constant_ (m .bias , 0 )
67- elif isinstance (m , nn .Linear ):
68- nn .init .normal_ (m .weight , 0 , 0.01 )
69- nn .init .constant_ (m .bias , 0 )
7057def test ():
7158 net = AlexNet ()
7259 x = torch .randn (2 ,3 ,32 ,32 )
0 commit comments