@@ -79,7 +79,7 @@ class ResNet(nn.Module):
     -> (16, 16, 128) -> [Res3] -> (8, 8, 256) -> [Res4] -> (4, 4, 512) -> [AvgPool]
     -> (1, 1, 512) -> [Reshape] -> (512) -> [Linear] -> (10)
     """
-    def __init__(self, block, num_blocks, num_classes=10, verbose = False):
+    def __init__(self, block, num_blocks, num_classes=10, verbose = False, init_weights = True):
         super(ResNet, self).__init__()
         self.verbose = verbose
         self.in_channels = 64
@@ -97,7 +97,11 @@ def __init__(self, block, num_blocks, num_classes=10, verbose = False):
         # so a 4 x 4 average pooling is used here
         self.avg_pool = nn.AvgPool2d(kernel_size=4)
         self.classifer = nn.Linear(512 * block.expansion, num_classes)
-
+
+        if init_weights:
+            self._initialize_weights()
+
+
     def _make_layer(self, block, out_channels, num_blocks, stride):
         # the first block does the downsampling
         strides = [stride] + [1] * (num_blocks - 1)
@@ -130,6 +134,18 @@ def forward(self, x):
         out = self.classifer(out)
         return out
 
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
 def ResNet18(verbose=False):
     return ResNet(BasicBlock, [2,2,2,2], verbose=verbose)
 
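A minimal usage sketch of the change above, assuming this file is importable as `resnet` (the module name is hypothetical) and that the rest of the class (the stem conv, the four residual stages, and forward) matches the CIFAR-10 layout described in the class docstring:

    import torch
    from resnet import ResNet, ResNet18, BasicBlock  # hypothetical module path

    model = ResNet18()             # init_weights defaults to True, so _initialize_weights() runs
    x = torch.randn(2, 3, 32, 32)  # CIFAR-10 sized input, matching the 4x4 AvgPool
    print(model(x).shape)          # torch.Size([2, 10])

    # Passing init_weights=False keeps PyTorch's default per-layer initialization
    baseline = ResNet(BasicBlock, [2, 2, 2, 2], init_weights=False)

With mode='fan_out', kaiming_normal_ scales the weight variance by the number of output connections of each conv filter, the pairing suggested for ReLU networks in He et al.; BatchNorm layers are reset to weight 1 and bias 0, and the Linear head gets small Gaussian weights.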