Skip to content

Commit 61066be

Browse files
authored
Merge pull request #90 from lukemelas/swish
Add memory-efficient and export-friendly swish activation functions
2 parents 4268864 + 7f828f6 commit 61066be

File tree

5 files changed

+33
-15
lines changed

5 files changed

+33
-15
lines changed

README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
# EfficientNet PyTorch
22

3-
### Update (October 12, 2019)
3+
### Update (October 15, 2019)
44

5-
This update makes the Swish activation function more memory-efficient. It also addresses pull requests #72, #73, #85, and #86. Thanks to the authors of all the pull requests!
5+
This update allows you to choose whether to use a memory-efficient Swish activation. The memory-efficient version is used by default, but it cannot be used when exporting a model with PyTorch JIT. For that case, we have also included a standard (export-friendly) Swish activation function. To switch to the export-friendly version, call `model.set_swish(memory_efficient=False)` after loading your desired model. This update addresses issues [#88](https://github.com/lukemelas/EfficientNet-PyTorch/pull/88) and [#89](https://github.com/lukemelas/EfficientNet-PyTorch/pull/89).
6+
7+
#### Update (October 12, 2019)
8+
9+
This update makes the Swish activation function more memory-efficient. It also addresses pull requests [#72](https://github.com/lukemelas/EfficientNet-PyTorch/pull/72), [#73](https://github.com/lukemelas/EfficientNet-PyTorch/pull/73), [#85](https://github.com/lukemelas/EfficientNet-PyTorch/pull/85), and [#86](https://github.com/lukemelas/EfficientNet-PyTorch/pull/86). Thanks to the authors of all the pull requests!
610

711
### Update (July 31, 2019)
812

efficientnet_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.5.0"
1+
__version__ = "0.5.1"
22
from .model import EfficientNet
33
from .utils import (
44
GlobalParams,

efficientnet_pytorch/model.py

100644100755
Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
from torch.nn import functional as F
44

55
from .utils import (
6-
relu_fn,
76
round_filters,
87
round_repeats,
98
drop_connect,
109
get_same_padding_conv2d,
1110
get_model_params,
1211
efficientnet_params,
1312
load_pretrained_weights,
13+
Swish,
14+
MemoryEfficientSwish,
1415
)
1516

1617
class MBConvBlock(nn.Module):
@@ -61,6 +62,7 @@ def __init__(self, block_args, global_params):
6162
final_oup = self._block_args.output_filters
6263
self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
6364
self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
65+
self._swish = MemoryEfficientSwish()
6466

6567
def forward(self, inputs, drop_connect_rate=None):
6668
"""
@@ -72,13 +74,13 @@ def forward(self, inputs, drop_connect_rate=None):
7274
# Expansion and Depthwise Convolution
7375
x = inputs
7476
if self._block_args.expand_ratio != 1:
75-
x = relu_fn(self._bn0(self._expand_conv(inputs)))
76-
x = relu_fn(self._bn1(self._depthwise_conv(x)))
77+
x = self._swish(self._bn0(self._expand_conv(inputs)))
78+
x = self._swish(self._bn1(self._depthwise_conv(x)))
7779

7880
# Squeeze and Excitation
7981
if self.has_se:
8082
x_squeezed = F.adaptive_avg_pool2d(x, 1)
81-
x_squeezed = self._se_expand(relu_fn(self._se_reduce(x_squeezed)))
83+
x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
8284
x = torch.sigmoid(x_squeezed) * x
8385

8486
x = self._bn2(self._project_conv(x))
@@ -91,6 +93,10 @@ def forward(self, inputs, drop_connect_rate=None):
9193
x = x + inputs # skip connection
9294
return x
9395

96+
def set_swish(self, memory_efficient=True):
97+
"""Sets swish function as memory efficient (for training) or standard (for export)"""
98+
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
99+
94100

95101
class EfficientNet(nn.Module):
96102
"""
@@ -153,12 +159,20 @@ def __init__(self, blocks_args=None, global_params=None):
153159
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
154160
self._dropout = nn.Dropout(self._global_params.dropout_rate)
155161
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
162+
self._swish = MemoryEfficientSwish()
163+
164+
def set_swish(self, memory_efficient=True):
165+
"""Sets swish function as memory efficient (for training) or standard (for export)"""
166+
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
167+
for block in self._blocks:
168+
block.set_swish(memory_efficient)
169+
156170

157171
def extract_features(self, inputs):
158172
""" Returns output of the final convolution layer """
159173

160174
# Stem
161-
x = relu_fn(self._bn0(self._conv_stem(inputs)))
175+
x = self._swish(self._bn0(self._conv_stem(inputs)))
162176

163177
# Blocks
164178
for idx, block in enumerate(self._blocks):
@@ -168,7 +182,7 @@ def extract_features(self, inputs):
168182
x = block(x, drop_connect_rate=drop_connect_rate)
169183

170184
# Head
171-
x = relu_fn(self._bn1(self._conv_head(x)))
185+
x = self._swish(self._bn1(self._conv_head(x)))
172186

173187
return x
174188

efficientnet_pytorch/utils.py

100644100755
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ def backward(ctx, grad_output):
4747
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
4848

4949

50-
class Swish(nn.Module):
51-
@staticmethod
52-
def forward(x):
50+
class MemoryEfficientSwish(nn.Module):
51+
def forward(self, x):
5352
return SwishImplementation.apply(x)
5453

55-
56-
relu_fn = Swish()
54+
class Swish(nn.Module):
55+
def forward(self, x):
56+
return x * torch.sigmoid(x)
5757

5858

5959
def round_filters(filters, global_params):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
1919
AUTHOR = 'Luke'
2020
REQUIRES_PYTHON = '>=3.5.0'
21-
VERSION = '0.5.0'
21+
VERSION = '0.5.1'
2222

2323
# What packages are required for this module to be executed?
2424
REQUIRED = [

0 commit comments

Comments
 (0)