
Commit 50a2bf2

Merge pull request #44 from lukemelas/add-export-and-improve-code: "Add export and improve code"

2 parents: 125e823 + 985d0e8

File tree: 6 files changed, +130 -39 lines


.gitignore

Lines changed: 4 additions & 0 deletions

@@ -114,6 +114,10 @@ example/test*
 *.pth*
 examples/imagenet/data/
 !examples/imagenet/data/README.md
+tmp
+tf_to_pytorch/pretrained_tensorflow
+!tf_to_pytorch/pretrained_tensorflow/download.sh
+examples/imagenet/run.sh
README.md

Lines changed: 34 additions & 2 deletions

@@ -1,13 +1,28 @@
 # EfficientNet PyTorch
 
+### Update (June 29, 2019)
+
+_Upgrade the pip package with_ `pip install --upgrade efficientnet-pytorch`
+
+This update adds easy model exporting ([#20](https://github.com/lukemelas/EfficientNet-PyTorch/issues/20)) and feature extraction ([#38](https://github.com/lukemelas/EfficientNet-PyTorch/issues/38)).
+
+* [Example: Export to ONNX](#example-export)
+* [Example: Extract features](#example-feature-extraction)
+* Also: fixed a CUDA/CPU bug ([#32](https://github.com/lukemelas/EfficientNet-PyTorch/issues/32))
+
+It is also now incredibly simple to load a pretrained model with a new number of classes for transfer learning:
+```python
+model = EfficientNet.from_pretrained('efficientnet-b1', num_classes=23)
+```
+
 ### Update (June 23, 2019)
 
 The B4 and B5 models are now available. Their usage is identical to the other models:
 ```python
 from efficientnet_pytorch import EfficientNet
 model = EfficientNet.from_pretrained('efficientnet-b4')
 ```
-Upgrade the pip package with `pip install --upgrade efficientnet-pytorch`.
 
 ### Overview
 This repository contains an op-for-op PyTorch reimplementation of [EfficientNet](https://arxiv.org/abs/1905.11946), along with pre-trained models and examples.

@@ -32,6 +47,7 @@ _Upcoming features_: In the next few days, you will be able to:
 * [Load pretrained models](#loading-pretrained-models)
 * [Example: Classify](#example-classification)
 * [Example: Extract features](#example-feature-extraction)
+* [Example: Export to ONNX](#example-export)
 6. [Contributing](#contributing)
 
 ### About EfficientNet

@@ -160,9 +176,25 @@ model = EfficientNet.from_pretrained('efficientnet-b0')
 print(img.shape) # torch.Size([1, 3, 224, 224])
 
 features = model.extract_features(img)
-print(features.shape) # torch.Size([1, 320, 7, 7])
+print(features.shape) # torch.Size([1, 1280, 7, 7])
 ```
 
+#### Example: Export to ONNX
+
+Exporting to ONNX for deploying to production is now simple:
+```python
+import torch
+from efficientnet_pytorch import EfficientNet
+
+model = EfficientNet.from_pretrained('efficientnet-b1')
+dummy_input = torch.randn(10, 3, 240, 240)
+
+torch.onnx.export(model, dummy_input, "test-b1.onnx", verbose=True)
+```
+
+[Here](https://colab.research.google.com/drive/1rOAEXeXHaA8uo3aG2YcFDHItlRJMV0VP) is a Colab example.
+
 #### ImageNet
 
 See `examples/imagenet` for details about evaluating on ImageNet.
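A note on using the exported file: the ONNX graph can be verified and run outside PyTorch. Below is a minimal sketch, not part of this commit, assuming the `onnx` and `onnxruntime` packages are installed and the `test-b1.onnx` file from the README example above exists:

```python
# Sketch only (not part of this commit): load, check, and run the
# "test-b1.onnx" file produced by the README example above.
# Assumes `pip install onnx onnxruntime`.
import numpy as np
import onnx
import onnxruntime

onnx_model = onnx.load("test-b1.onnx")
onnx.checker.check_model(onnx_model)  # raises if the exported graph is malformed

session = onnxruntime.InferenceSession("test-b1.onnx")
input_name = session.get_inputs()[0].name

# The graph was traced with a fixed input shape of (10, 3, 240, 240)
batch = np.random.randn(10, 3, 240, 240).astype(np.float32)
(logits,) = session.run(None, {input_name: batch})
print(logits.shape)  # (10, 1000): ImageNet class logits
```

The fixed input shape is a consequence of the static padding introduced in `utils.py` below, which is what makes the export possible in the first place.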

efficientnet_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-__version__ = "0.1.0"
+__version__ = "0.3.0"
 from .model import EfficientNet
 from .utils import (
     GlobalParams,

efficientnet_pytorch/model.py

Lines changed: 22 additions & 14 deletions

@@ -7,7 +7,7 @@
     round_filters,
     round_repeats,
     drop_connect,
-    Conv2dSamePadding,
+    get_same_padding_conv2d,
     get_model_params,
     efficientnet_params,
     load_pretrained_weights,

@@ -33,30 +33,33 @@ def __init__(self, block_args, global_params):
         self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
         self.id_skip = block_args.id_skip  # skip connection and drop connect
 
+        # Get static or dynamic convolution depending on image size
+        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
+
         # Expansion phase
         inp = self._block_args.input_filters  # number of input channels
         oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
         if self._block_args.expand_ratio != 1:
-            self._expand_conv = Conv2dSamePadding(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
+            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
             self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
 
         # Depthwise convolution phase
         k = self._block_args.kernel_size
         s = self._block_args.stride
-        self._depthwise_conv = Conv2dSamePadding(
+        self._depthwise_conv = Conv2d(
             in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
             kernel_size=k, stride=s, bias=False)
         self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
 
         # Squeeze and Excitation layer, if desired
         if self.has_se:
             num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
-            self._se_reduce = Conv2dSamePadding(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
-            self._se_expand = Conv2dSamePadding(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
+            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
+            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
 
         # Output phase
         final_oup = self._block_args.output_filters
-        self._project_conv = Conv2dSamePadding(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
+        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
         self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
 
     def forward(self, inputs, drop_connect_rate=None):

@@ -109,14 +112,17 @@ def __init__(self, blocks_args=None, global_params=None):
         self._global_params = global_params
         self._blocks_args = blocks_args
 
+        # Get static or dynamic convolution depending on image size
+        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
+
         # Batch norm parameters
         bn_mom = 1 - self._global_params.batch_norm_momentum
         bn_eps = self._global_params.batch_norm_epsilon
 
         # Stem
         in_channels = 3  # rgb
         out_channels = round_filters(32, self._global_params)  # number of output channels
-        self._conv_stem = Conv2dSamePadding(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
         self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
 
         # Build blocks

@@ -140,7 +146,7 @@ def __init__(self, blocks_args=None, global_params=None):
         # Head
         in_channels = block_args.output_filters  # output of final block
         out_channels = round_filters(1280, self._global_params)
-        self._conv_head = Conv2dSamePadding(in_channels, out_channels, kernel_size=1, bias=False)
+        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
         self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
 
         # Final linear layer

@@ -158,7 +164,10 @@ def extract_features(self, inputs):
             drop_connect_rate = self._global_params.drop_connect_rate
             if drop_connect_rate:
                 drop_connect_rate *= float(idx) / len(self._blocks)
-            x = block(x, drop_connect_rate)
+            x = block(x, drop_connect_rate=drop_connect_rate)
+
+        # Head
+        x = relu_fn(self._bn1(self._conv_head(x)))
 
         return x

@@ -168,8 +177,7 @@ def forward(self, inputs):
         # Convolution layers
         x = self.extract_features(inputs)
 
-        # Head
-        x = relu_fn(self._bn1(self._conv_head(x)))
+        # Pooling and final linear layer
         x = F.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1)
         if self._dropout:
             x = F.dropout(x, p=self._dropout, training=self.training)

@@ -183,9 +191,9 @@ def from_name(cls, model_name, override_params=None):
         return EfficientNet(blocks_args, global_params)
 
     @classmethod
-    def from_pretrained(cls, model_name):
-        model = EfficientNet.from_name(model_name)
-        load_pretrained_weights(model, model_name)
+    def from_pretrained(cls, model_name, num_classes=1000):
+        model = EfficientNet.from_name(model_name, override_params={'num_classes': num_classes})
+        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
         return model
 
     @classmethod
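The new `from_pretrained(model_name, num_classes)` signature pairs naturally with transfer learning: the backbone keeps its pretrained weights while the final `_fc` layer is re-initialized for the new class count. A minimal fine-tuning sketch, not part of this commit (the random batch stands in for real data):

```python
# Sketch only (not part of this commit): fine-tune the re-initialized
# classifier head on a hypothetical 23-class task.
import torch
from torch import nn, optim
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b1', num_classes=23)

# Freeze the pretrained backbone; train only the new fully connected layer
for param in model.parameters():
    param.requires_grad = False
for param in model._fc.parameters():
    param.requires_grad = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model._fc.parameters(), lr=1e-2, momentum=0.9)

model.train()
images = torch.randn(8, 3, 240, 240)   # stand-in batch (b1 expects 240x240)
labels = torch.randint(0, 23, (8,))    # stand-in labels
optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()                        # gradients flow only into model._fc
optimizer.step()
```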

efficientnet_pytorch/utils.py

Lines changed: 61 additions & 11 deletions

@@ -6,6 +6,7 @@
 import re
 import math
 import collections
+from functools import partial
 import torch
 from torch import nn
 from torch.nn import functional as F

@@ -21,7 +22,7 @@
 GlobalParams = collections.namedtuple('GlobalParams', [
     'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
     'num_classes', 'width_coefficient', 'depth_coefficient',
-    'depth_divisor', 'min_depth', 'drop_connect_rate',])
+    'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
 
 
 # Parameters for an individual model block

@@ -75,8 +76,16 @@ def drop_connect(inputs, p, training):
     return output
 
 
-class Conv2dSamePadding(nn.Conv2d):
-    """ 2D Convolutions like TensorFlow """
+def get_same_padding_conv2d(image_size=None):
+    """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+        Static padding is necessary for ONNX exporting of models. """
+    if image_size is None:
+        return Conv2dDynamicSamePadding
+    else:
+        return partial(Conv2dStaticSamePadding, image_size=image_size)
+
+class Conv2dDynamicSamePadding(nn.Conv2d):
+    """ 2D Convolutions like TensorFlow, for a dynamic image size """
     def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
         super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
         self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]]*2

@@ -93,6 +102,39 @@ def forward(self, x):
         return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
 
 
+class Conv2dStaticSamePadding(nn.Conv2d):
+    """ 2D Convolutions like TensorFlow, for a fixed image size """
+    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
+        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
+        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+        # Calculate padding based on image size and save it
+        assert image_size is not None
+        ih, iw = image_size if type(image_size) == list else [image_size, image_size]
+        kh, kw = self.weight.size()[-2:]
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
+        else:
+            self.static_padding = Identity()
+
+    def forward(self, x):
+        x = self.static_padding(x)
+        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+        return x
+
+
+class Identity(nn.Module):
+    def __init__(self,):
+        super(Identity, self).__init__()
+
+    def forward(self, input):
+        return input
+
+
 ########################################################################
 ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
 ########################################################################

@@ -189,8 +231,8 @@ def encode(blocks_args):
         return block_strings
 
 
-def efficientnet(width_coefficient=None, depth_coefficient=None,
-                 dropout_rate=0.2, drop_connect_rate=0.2):
+def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
+                 drop_connect_rate=0.2, image_size=None, num_classes=1000):
     """ Creates a efficientnet model. """
 
     blocks_args = [

@@ -207,11 +249,12 @@ def efficientnet(width_coefficient=None, depth_coefficient=None,
         dropout_rate=dropout_rate,
         drop_connect_rate=drop_connect_rate,
         # data_format='channels_last',  # removed, this is always true in PyTorch
-        num_classes=1000,
+        num_classes=num_classes,
         width_coefficient=width_coefficient,
         depth_coefficient=depth_coefficient,
         depth_divisor=8,
-        min_depth=None
+        min_depth=None,
+        image_size=image_size,
     )
 
     return blocks_args, global_params

@@ -220,9 +263,10 @@ def efficientnet(width_coefficient=None, depth_coefficient=None,
 def get_model_params(model_name, override_params):
     """ Get the block args and global params for a given model """
     if model_name.startswith('efficientnet'):
-        w, d, _, p = efficientnet_params(model_name)
+        w, d, s, p = efficientnet_params(model_name)
         # note: all models have drop connect rate = 0.2
-        blocks_args, global_params = efficientnet(width_coefficient=w, depth_coefficient=d, dropout_rate=p)
+        blocks_args, global_params = efficientnet(
+            width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
     else:
         raise NotImplementedError('model name is not pre-defined: %s' % model_name)
     if override_params:

@@ -240,8 +284,14 @@ def get_model_params(model_name, override_params):
     'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet-b5-586e6cc6.pth',
 }
 
-def load_pretrained_weights(model, model_name):
+def load_pretrained_weights(model, model_name, load_fc=True):
     """ Loads pretrained weights, and downloads if loading for the first time. """
     state_dict = model_zoo.load_url(url_map[model_name])
-    model.load_state_dict(state_dict)
+    if load_fc:
+        model.load_state_dict(state_dict)
+    else:
+        state_dict.pop('_fc.weight')
+        state_dict.pop('_fc.bias')
+        res = model.load_state_dict(state_dict, strict=False)
+        assert str(res.missing_keys) == str(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
     print('Loaded pretrained weights for {}'.format(model_name))
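To make the static padding arithmetic concrete: for a 240×240 input through a 3×3 convolution with stride 2, `oh = ceil(240 / 2) = 120` and `pad_h = max((120 - 1) * 2 + (3 - 1) + 1 - 240, 0) = 1`, split as `ZeroPad2d((0, 1, 0, 1))`. A minimal sketch, not part of this commit, checking that the static and dynamic flavors agree:

```python
# Sketch only (not part of this commit): confirm that static padding, computed
# once at construction, matches dynamic padding, computed in every forward().
import torch
from efficientnet_pytorch.utils import get_same_padding_conv2d

x = torch.randn(1, 3, 240, 240)

DynamicConv2d = get_same_padding_conv2d(image_size=None)  # pads inside forward()
StaticConv2d = get_same_padding_conv2d(image_size=240)    # precomputed ZeroPad2d

dynamic = DynamicConv2d(3, 32, kernel_size=3, stride=2, bias=False)
static = StaticConv2d(3, 32, kernel_size=3, stride=2, bias=False)
static.weight.data.copy_(dynamic.weight.data)  # identical weights for comparison

print(static.static_padding)  # ZeroPad2d with (left, right, top, bottom) = (0, 1, 0, 1)
out_dyn, out_sta = dynamic(x), static(x)
print(out_dyn.shape)                     # torch.Size([1, 32, 120, 120]): ceil(240 / 2)
print(torch.allclose(out_dyn, out_sta))  # True: both pad identically
```

The dynamic flavor traces to a data-dependent `F.pad` call, which is why the static, precomputed variant is needed for ONNX export.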

setup.py

Lines changed: 8 additions & 11 deletions

@@ -18,7 +18,7 @@
 
 AUTHOR = 'Luke'
 REQUIRES_PYTHON = '>=3.5.0'
-VERSION = '0.2.0'
+VERSION = '0.3.0'
 
 # What packages are required for this module to be executed?
 REQUIRED = [

@@ -109,16 +109,13 @@ def run(self):
     extras_require=EXTRAS,
     include_package_data=True,
     license='Apache',
-    # classifiers=[
-    #     # Trove classifiers
-    #     # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
-    #     'License :: OSI Approved :: MIT License',
-    #     'Programming Language :: Python',
-    #     'Programming Language :: Python :: 3',
-    #     'Programming Language :: Python :: 3.6',
-    #     'Programming Language :: Python :: Implementation :: CPython',
-    #     'Programming Language :: Python :: Implementation :: PyPy'
-    # ],
+    classifiers=[
+        # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
+        'License :: OSI Approved :: Apache Software License',
+        'Programming Language :: Python',
+        'Programming Language :: Python :: 3',
+        'Programming Language :: Python :: 3.6',
+    ],
     # $ setup.py publish support.
     cmdclass={
         'upload': UploadCommand,
