66import re
77import math
88import collections
9+ from functools import partial
910import torch
1011from torch import nn
1112from torch .nn import functional as F
# Record of model-wide settings. `image_size` is the last field; it is the
# value handed to get_same_padding_conv2d-style helpers to pick static padding.
GlobalParams = collections.namedtuple('GlobalParams', [
    'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
    'num_classes', 'width_coefficient', 'depth_coefficient',
    'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
2526
2627
2728# Parameters for an individual model block
@@ -75,8 +76,16 @@ def drop_connect(inputs, p, training):
7576 return output
7677
7778
78- class Conv2dSamePadding (nn .Conv2d ):
79- """ 2D Convolutions like TensorFlow """
def get_same_padding_conv2d(image_size=None):
    """ Return the "same"-padding Conv2d class to construct layers with.

    With an `image_size`, the result is Conv2dStaticSamePadding with that
    size pre-bound as a keyword argument; without one, it is
    Conv2dDynamicSamePadding. Static padding is necessary for ONNX
    exporting of models. """
    if image_size is not None:
        # Bind the size now so callers can use the result like a plain class.
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding
86+
87+ class Conv2dDynamicSamePadding (nn .Conv2d ):
88+ """ 2D Convolutions like TensorFlow, for a dynamic image size """
8089 def __init__ (self , in_channels , out_channels , kernel_size , stride = 1 , dilation = 1 , groups = 1 , bias = True ):
8190 super ().__init__ (in_channels , out_channels , kernel_size , stride , 0 , dilation , groups , bias )
8291 self .stride = self .stride if len (self .stride ) == 2 else [self .stride [0 ]]* 2
@@ -93,6 +102,31 @@ def forward(self, x):
93102 return F .conv2d (x , self .weight , self .bias , self .stride , self .padding , self .dilation , self .groups )
94103
95104
class Conv2dStaticSamePadding(nn.Conv2d):
    """ 2D Convolutions like TensorFlow, for a fixed image size.

    The zero padding needed for TensorFlow-style "same" output size is
    computed once in `__init__` from `image_size` and stored as the
    `static_padding` sub-module; `forward` applies it and then convolves.

    Args:
        in_channels, out_channels, kernel_size: passed through to nn.Conv2d.
        image_size: expected input size. An int means a square image; a list
            or tuple is unpacked as (height, width). Must not be None.
        **kwargs: remaining nn.Conv2d keyword arguments (stride, dilation,
            groups, bias, ...). A `padding` given here is still applied by the
            convolution itself, on top of the static padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None, 'image_size is required for static same padding'
        # Accept tuples as well as lists; an exact `type(...) == list` check
        # would treat (h, w) as a scalar and fail in the division below.
        if isinstance(image_size, (list, tuple)):
            ih, iw = image_size
        else:
            ih, iw = image_size, image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        # Target output size: ceil(input / stride).
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        # Total padding needed per axis to reach that output size (never negative).
        pad_h = max((oh - 1) * sh + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * sw + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            # ZeroPad2d order is (left, right, top, bottom); an odd total puts
            # the extra pixel on the right / bottom.
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        """ Pad `x` with the precomputed padding, then apply the convolution. """
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x
128+
129+
########################################################################
############## HELPER FUNCTIONS FOR LOADING MODEL PARAMS ###############
########################################################################
@@ -189,8 +223,8 @@ def encode(blocks_args):
189223 return block_strings
190224
191225
192- def efficientnet (width_coefficient = None , depth_coefficient = None ,
193- dropout_rate = 0.2 , drop_connect_rate = 0.2 ):
226+ def efficientnet (width_coefficient = None , depth_coefficient = None , dropout_rate = 0.2 ,
227+ drop_connect_rate = 0.2 , image_size = None , num_classes = 1000 ):
194228 """ Creates a efficientnet model. """
195229
196230 blocks_args = [
@@ -207,11 +241,12 @@ def efficientnet(width_coefficient=None, depth_coefficient=None,
207241 dropout_rate = dropout_rate ,
208242 drop_connect_rate = drop_connect_rate ,
209243 # data_format='channels_last', # removed, this is always true in PyTorch
210- num_classes = 1000 ,
244+ num_classes = num_classes ,
211245 width_coefficient = width_coefficient ,
212246 depth_coefficient = depth_coefficient ,
213247 depth_divisor = 8 ,
214- min_depth = None
248+ min_depth = None ,
249+ image_size = image_size ,
215250 )
216251
217252 return blocks_args , global_params
@@ -220,9 +255,10 @@ def efficientnet(width_coefficient=None, depth_coefficient=None,
220255def get_model_params (model_name , override_params ):
221256 """ Get the block args and global params for a given model """
222257 if model_name .startswith ('efficientnet' ):
223- w , d , _ , p = efficientnet_params (model_name )
258+ w , d , s , p = efficientnet_params (model_name )
224259 # note: all models have drop connect rate = 0.2
225- blocks_args , global_params = efficientnet (width_coefficient = w , depth_coefficient = d , dropout_rate = p )
260+ blocks_args , global_params = efficientnet (
261+ width_coefficient = w , depth_coefficient = d , dropout_rate = p , image_size = s )
226262 else :
227263 raise NotImplementedError ('model name is not pre-defined: %s' % model_name )
228264 if override_params :
@@ -240,7 +276,7 @@ def get_model_params(model_name, override_params):
240276 'efficientnet-b5' : 'http://storage.googleapis.com/public-models/efficientnet-b5-586e6cc6.pth' ,
241277}
242278
243- def load_pretrained_weights (model , model_name ):
279+ def load_pretrained_weights (model , model_name , load_fc = True ):
244280 """ Loads pretrained weights, and downloads if loading for the first time. """
245281 state_dict = model_zoo .load_url (url_map [model_name ])
246282 model .load_state_dict (state_dict )
0 commit comments