@@ -211,8 +211,7 @@ class EfficientNet(nn.Module):
211211 def __init__ (self , block_args , num_classes = 1000 , num_features = 1280 , in_chans = 3 , stem_size = 32 ,
212212 channel_multiplier = 1.0 , channel_divisor = 8 , channel_min = None ,
213213 pad_type = '' , act_layer = nn .ReLU , drop_rate = 0. , drop_connect_rate = 0. ,
214- se_kwargs = None , norm_layer = nn .BatchNorm2d , norm_kwargs = None ,
215- global_pool = 'avg' , weight_init = 'goog' ):
214+ se_kwargs = None , norm_layer = nn .BatchNorm2d , norm_kwargs = None , global_pool = 'avg' ):
216215 super (EfficientNet , self ).__init__ ()
217216 norm_kwargs = norm_kwargs or {}
218217
@@ -245,11 +244,7 @@ def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3,
245244 # Classifier
246245 self .classifier = nn .Linear (self .num_features * self .global_pool .feat_mult (), self .num_classes )
247246
248- for m in self .modules ():
249- if weight_init == 'goog' :
250- efficientnet_init_goog (m )
251- else :
252- efficientnet_init_default (m )
247+ efficientnet_init_weights (self )
253248
254249 def as_sequential (self ):
255250 layers = [self .conv_stem , self .bn1 , self .act1 ]
@@ -262,14 +257,10 @@ def get_classifier(self):
262257 return self .classifier
263258
264259 def reset_classifier (self , num_classes , global_pool = 'avg' ):
265- self .global_pool = SelectAdaptivePool2d (pool_type = global_pool )
266260 self .num_classes = num_classes
267- del self .classifier
268- if num_classes :
269- self .classifier = nn .Linear (
270- self .num_features * self .global_pool .feat_mult (), num_classes )
271- else :
272- self .classifier = None
261+ self .global_pool = SelectAdaptivePool2d (pool_type = global_pool )
262+ self .classifier = nn .Linear (
263+ self .num_features * self .global_pool .feat_mult (), num_classes ) if num_classes else None
273264
274265 def forward_features (self , x ):
275266 x = self .conv_stem (x )
@@ -300,7 +291,7 @@ class EfficientNetFeatures(nn.Module):
300291 def __init__ (self , block_args , out_indices = (0 , 1 , 2 , 3 , 4 ), feature_location = 'pre_pwl' ,
301292 in_chans = 3 , stem_size = 32 , channel_multiplier = 1.0 , channel_divisor = 8 , channel_min = None ,
302293 output_stride = 32 , pad_type = '' , act_layer = nn .ReLU , drop_rate = 0. , drop_connect_rate = 0. ,
303- se_kwargs = None , norm_layer = nn .BatchNorm2d , norm_kwargs = None , weight_init = 'goog' ):
294+ se_kwargs = None , norm_layer = nn .BatchNorm2d , norm_kwargs = None ):
304295 super (EfficientNetFeatures , self ).__init__ ()
305296 norm_kwargs = norm_kwargs or {}
306297
@@ -326,12 +317,7 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pr
326317 self .feature_info = builder .features # builder provides info about feature channels for each block
327318 self ._in_chs = builder .in_chs
328319
329- for m in self .modules ():
330- if weight_init == 'goog' :
331- efficientnet_init_goog (m )
332- else :
333- efficientnet_init_default (m )
334-
320+ efficientnet_init_weights (self )
335321 if _DEBUG :
336322 for k , v in self .feature_info .items ():
337323 print ('Feature idx: {}: Name: {}, Channels: {}' .format (k , v ['name' ], v ['num_chs' ]))
0 commit comments