1212from torch .nn import functional as F
1313from torch .utils import model_zoo
1414
15-
1615########################################################################
1716############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
1817########################################################################
2423 'num_classes' , 'width_coefficient' , 'depth_coefficient' ,
2524 'depth_divisor' , 'min_depth' , 'drop_connect_rate' , 'image_size' ])
2625
27-
# Parameters for an individual model block.
# Fields: kernel_size, num_repeat, input_filters, output_filters,
# expand_ratio, id_skip, stride, se_ratio.
BlockArgs = collections.namedtuple('BlockArgs', [
    'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
    'expand_ratio', 'id_skip', 'stride', 'se_ratio'])
3230
33-
# Change namedtuple defaults: every field defaults to None, so an instance
# can be created from any subset of its fields.
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
3734
3835
class SwishImplementation(torch.autograd.Function):
    """Swish activation (``x * sigmoid(x)``) with a hand-written backward pass.

    Only the input tensor is saved for the backward pass; the sigmoid is
    recomputed there rather than kept from the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and save ``i`` for ``backward``."""
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient with respect to the forward input.

        Uses d/di [i * s(i)] = s(i) * (1 + i * (1 - s(i))), s = sigmoid.
        """
        # `saved_tensors` is the supported accessor for tensors stored with
        # `save_for_backward`; `saved_variables` is deprecated.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class Swish(nn.Module):
    """Module wrapper that applies ``SwishImplementation`` to its input."""

    @staticmethod
    def forward(x):
        """Return the Swish activation of ``x``."""
        return SwishImplementation.apply(x)


# Kept under the historical name so existing `relu_fn(x)` call sites keep
# working; it is the Swish activation, not a ReLU.
relu_fn = Swish()
4257
4358
4459def round_filters (filters , global_params ):
@@ -84,11 +99,13 @@ def get_same_padding_conv2d(image_size=None):
8499 else :
85100 return partial (Conv2dStaticSamePadding , image_size = image_size )
86101
102+
87103class Conv2dDynamicSamePadding (nn .Conv2d ):
88104 """ 2D Convolutions like TensorFlow, for a dynamic image size """
105+
89106 def __init__ (self , in_channels , out_channels , kernel_size , stride = 1 , dilation = 1 , groups = 1 , bias = True ):
90107 super ().__init__ (in_channels , out_channels , kernel_size , stride , 0 , dilation , groups , bias )
91- self .stride = self .stride if len (self .stride ) == 2 else [self .stride [0 ]]* 2
108+ self .stride = self .stride if len (self .stride ) == 2 else [self .stride [0 ]] * 2
92109
93110 def forward (self , x ):
94111 ih , iw = x .size ()[- 2 :]
@@ -98,12 +115,13 @@ def forward(self, x):
98115 pad_h = max ((oh - 1 ) * self .stride [0 ] + (kh - 1 ) * self .dilation [0 ] + 1 - ih , 0 )
99116 pad_w = max ((ow - 1 ) * self .stride [1 ] + (kw - 1 ) * self .dilation [1 ] + 1 - iw , 0 )
100117 if pad_h > 0 or pad_w > 0 :
101- x = F .pad (x , [pad_w // 2 , pad_w - pad_w // 2 , pad_h // 2 , pad_h - pad_h // 2 ])
118+ x = F .pad (x , [pad_w // 2 , pad_w - pad_w // 2 , pad_h // 2 , pad_h - pad_h // 2 ])
102119 return F .conv2d (x , self .weight , self .bias , self .stride , self .padding , self .dilation , self .groups )
103120
104121
105122class Conv2dStaticSamePadding (nn .Conv2d ):
106123 """ 2D Convolutions like TensorFlow, for a fixed image size"""
124+
107125 def __init__ (self , in_channels , out_channels , kernel_size , image_size = None , ** kwargs ):
108126 super ().__init__ (in_channels , out_channels , kernel_size , ** kwargs )
109127 self .stride = self .stride if len (self .stride ) == 2 else [self .stride [0 ]] * 2
@@ -128,7 +146,7 @@ def forward(self, x):
128146
129147
130148class Identity (nn .Module ):
131- def __init__ (self ,):
149+ def __init__ (self , ):
132150 super (Identity , self ).__init__ ()
133151
134152 def forward (self , input ):
@@ -286,6 +304,7 @@ def get_model_params(model_name, override_params):
286304 'efficientnet-b7' : 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth' ,
287305}
288306
307+
289308def load_pretrained_weights (model , model_name , load_fc = True ):
290309 """ Loads pretrained weights, and downloads if loading for the first time. """
291310 state_dict = model_zoo .load_url (url_map [model_name ])
@@ -295,5 +314,5 @@ def load_pretrained_weights(model, model_name, load_fc=True):
295314 state_dict .pop ('_fc.weight' )
296315 state_dict .pop ('_fc.bias' )
297316 res = model .load_state_dict (state_dict , strict = False )
298- assert str (res .missing_keys ) == str (['_fc.weight' , '_fc.bias' ]), 'issue loading pretrained weights'
317+ assert set (res .missing_keys ) == set (['_fc.weight' , '_fc.bias' ]), 'issue loading pretrained weights'
299318 print ('Loaded pretrained weights for {}' .format (model_name ))
0 commit comments