Skip to content

Commit bd3c392

Browse files
committed
Added static padding
1 parent 125e823 commit bd3c392

File tree

3 files changed

+72
-24
lines changed

3 files changed

+72
-24
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ example/test*
114114
*.pth*
115115
examples/imagenet/data/
116116
!examples/imagenet/data/README.md
117+
tmp
118+
tf_to_pytorch/pretrained_tensorflow
119+
!tf_to_pytorch/pretrained_tensorflow/download.sh
120+
examples/imagenet/run.sh
117121

118122

119123

efficientnet_pytorch/model.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
round_filters,
88
round_repeats,
99
drop_connect,
10-
Conv2dSamePadding,
10+
get_same_padding_conv2d,
1111
get_model_params,
1212
efficientnet_params,
1313
load_pretrained_weights,
@@ -33,30 +33,33 @@ def __init__(self, block_args, global_params):
3333
self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
3434
self.id_skip = block_args.id_skip # skip connection and drop connect
3535

36+
# Get static or dynamic convolution depending on image size
37+
Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
38+
3639
# Expansion phase
3740
inp = self._block_args.input_filters # number of input channels
3841
oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
3942
if self._block_args.expand_ratio != 1:
40-
self._expand_conv = Conv2dSamePadding(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
43+
self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
4144
self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
4245

4346
# Depthwise convolution phase
4447
k = self._block_args.kernel_size
4548
s = self._block_args.stride
46-
self._depthwise_conv = Conv2dSamePadding(
49+
self._depthwise_conv = Conv2d(
4750
in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
4851
kernel_size=k, stride=s, bias=False)
4952
self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
5053

5154
# Squeeze and Excitation layer, if desired
5255
if self.has_se:
5356
num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
54-
self._se_reduce = Conv2dSamePadding(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
55-
self._se_expand = Conv2dSamePadding(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
57+
self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
58+
self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
5659

5760
# Output phase
5861
final_oup = self._block_args.output_filters
59-
self._project_conv = Conv2dSamePadding(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
62+
self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
6063
self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
6164

6265
def forward(self, inputs, drop_connect_rate=None):
@@ -109,14 +112,17 @@ def __init__(self, blocks_args=None, global_params=None):
109112
self._global_params = global_params
110113
self._blocks_args = blocks_args
111114

115+
# Get static or dynamic convolution depending on image size
116+
Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
117+
112118
# Batch norm parameters
113119
bn_mom = 1 - self._global_params.batch_norm_momentum
114120
bn_eps = self._global_params.batch_norm_epsilon
115121

116122
# Stem
117123
in_channels = 3 # rgb
118124
out_channels = round_filters(32, self._global_params) # number of output channels
119-
self._conv_stem = Conv2dSamePadding(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
125+
self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
120126
self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
121127

122128
# Build blocks
@@ -140,7 +146,7 @@ def __init__(self, blocks_args=None, global_params=None):
140146
# Head
141147
in_channels = block_args.output_filters # output of final block
142148
out_channels = round_filters(1280, self._global_params)
143-
self._conv_head = Conv2dSamePadding(in_channels, out_channels, kernel_size=1, bias=False)
149+
self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
144150
self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
145151

146152
# Final linear layer
@@ -158,7 +164,10 @@ def extract_features(self, inputs):
158164
drop_connect_rate = self._global_params.drop_connect_rate
159165
if drop_connect_rate:
160166
drop_connect_rate *= float(idx) / len(self._blocks)
161-
x = block(x, drop_connect_rate)
167+
x = block(x, drop_connect_rate=drop_connect_rate)
168+
169+
# Head
170+
x = relu_fn(self._bn1(self._conv_head(x)))
162171

163172
return x
164173

@@ -168,8 +177,7 @@ def forward(self, inputs):
168177
# Convolution layers
169178
x = self.extract_features(inputs)
170179

171-
# Head
172-
x = relu_fn(self._bn1(self._conv_head(x)))
180+
# Pooling and final linear layer
173181
x = F.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1)
174182
if self._dropout:
175183
x = F.dropout(x, p=self._dropout, training=self.training)
@@ -183,9 +191,9 @@ def from_name(cls, model_name, override_params=None):
183191
return EfficientNet(blocks_args, global_params)
184192

185193
@classmethod
186-
def from_pretrained(cls, model_name):
187-
model = EfficientNet.from_name(model_name)
188-
load_pretrained_weights(model, model_name)
194+
def from_pretrained(cls, model_name, num_classes=1000):
195+
model = EfficientNet.from_name(model_name, override_params={'num_classes': 1000})
196+
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
189197
return model
190198

191199
@classmethod

efficientnet_pytorch/utils.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import re
77
import math
88
import collections
9+
from functools import partial
910
import torch
1011
from torch import nn
1112
from torch.nn import functional as F
@@ -21,7 +22,7 @@
2122
GlobalParams = collections.namedtuple('GlobalParams', [
2223
'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
2324
'num_classes', 'width_coefficient', 'depth_coefficient',
24-
'depth_divisor', 'min_depth', 'drop_connect_rate',])
25+
'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
2526

2627

2728
# Parameters for an individual model block
@@ -75,8 +76,16 @@ def drop_connect(inputs, p, training):
7576
return output
7677

7778

78-
class Conv2dSamePadding(nn.Conv2d):
79-
""" 2D Convolutions like TensorFlow """
79+
def get_same_padding_conv2d(image_size=None):
    """Pick the "same"-padding convolution class to build layers with.

    Returns ``Conv2dDynamicSamePadding`` when no image size is given.
    Otherwise returns a ``functools.partial`` of ``Conv2dStaticSamePadding``
    with ``image_size`` already bound, so both results can be called with
    the same convolution arguments. Static padding is necessary for ONNX
    exporting of models.
    """
    if image_size is not None:
        # Bind the size now so callers never have to pass it per layer.
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding
86+
87+
class Conv2dDynamicSamePadding(nn.Conv2d):
88+
""" 2D Convolutions like TensorFlow, for a dynamic image size """
8089
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
8190
super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
8291
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]]*2
@@ -93,6 +102,31 @@ def forward(self, x):
93102
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
94103

95104

105+
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution like TensorFlow ("same" padding), for a fixed image size.

    The padding is computed once, at construction time, from ``image_size``
    and stored as the ``static_padding`` submodule; ``forward`` applies it
    before running the convolution.

    Args:
        in_channels: Number of input channels (forwarded to ``nn.Conv2d``).
        out_channels: Number of output channels (forwarded to ``nn.Conv2d``).
        kernel_size: Kernel size (forwarded to ``nn.Conv2d``).
        image_size: Input size the padding is computed for; either a single
            number (used for both height and width) or a ``[height, width]``
            list or tuple.
        **kwargs: Remaining keyword arguments forwarded to ``nn.Conv2d``
            (e.g. ``stride``, ``dilation``, ``groups``, ``bias``).

    Raises:
        ValueError: If ``image_size`` is None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if image_size is None:
            raise ValueError('image_size must be specified for static same padding')

        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # Guard kept from the original: make sure stride is a (height, width) pair.
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Accept tuples as well as lists for a (height, width) size.
        if isinstance(image_size, (list, tuple)):
            ih, iw = image_size
        else:
            ih, iw = image_size, image_size

        # Calculate padding based on image size and save it
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * sh + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * sw + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            # ZeroPad2d order is (left, right, top, bottom); any odd pixel
            # goes to the right/bottom side.
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        """Apply the precomputed padding, then the convolution."""
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x
128+
129+
96130
########################################################################
97131
############## HELPER FUNCTIONS FOR LOADING MODEL PARAMS ###############
98132
########################################################################
@@ -189,8 +223,8 @@ def encode(blocks_args):
189223
return block_strings
190224

191225

192-
def efficientnet(width_coefficient=None, depth_coefficient=None,
193-
dropout_rate=0.2, drop_connect_rate=0.2):
226+
def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
227+
drop_connect_rate=0.2, image_size=None, num_classes=1000):
194228
""" Creates a efficientnet model. """
195229

196230
blocks_args = [
@@ -207,11 +241,12 @@ def efficientnet(width_coefficient=None, depth_coefficient=None,
207241
dropout_rate=dropout_rate,
208242
drop_connect_rate=drop_connect_rate,
209243
# data_format='channels_last', # removed, this is always true in PyTorch
210-
num_classes=1000,
244+
num_classes=num_classes,
211245
width_coefficient=width_coefficient,
212246
depth_coefficient=depth_coefficient,
213247
depth_divisor=8,
214-
min_depth=None
248+
min_depth=None,
249+
image_size=image_size,
215250
)
216251

217252
return blocks_args, global_params
@@ -220,9 +255,10 @@ def efficientnet(width_coefficient=None, depth_coefficient=None,
220255
def get_model_params(model_name, override_params):
221256
""" Get the block args and global params for a given model """
222257
if model_name.startswith('efficientnet'):
223-
w, d, _, p = efficientnet_params(model_name)
258+
w, d, s, p = efficientnet_params(model_name)
224259
# note: all models have drop connect rate = 0.2
225-
blocks_args, global_params = efficientnet(width_coefficient=w, depth_coefficient=d, dropout_rate=p)
260+
blocks_args, global_params = efficientnet(
261+
width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
226262
else:
227263
raise NotImplementedError('model name is not pre-defined: %s' % model_name)
228264
if override_params:
@@ -240,7 +276,7 @@ def get_model_params(model_name, override_params):
240276
'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet-b5-586e6cc6.pth',
241277
}
242278

243-
def load_pretrained_weights(model, model_name):
279+
def load_pretrained_weights(model, model_name, load_fc=True):
244280
""" Loads pretrained weights, and downloads if loading for the first time. """
245281
state_dict = model_zoo.load_url(url_map[model_name])
246282
model.load_state_dict(state_dict)

0 commit comments

Comments
 (0)