Skip to content

Commit 7e4f8a7

Browse files
author
Dmytro Panchenko
committed
Patch memory issues in swish activation
1 parent de40cbf commit 7e4f8a7

File tree

1 file changed

+28
-9
lines changed

1 file changed

+28
-9
lines changed

efficientnet_pytorch/utils.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from torch.nn import functional as F
1313
from torch.utils import model_zoo
1414

15-
1615
########################################################################
1716
############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
1817
########################################################################
@@ -24,21 +23,37 @@
2423
'num_classes', 'width_coefficient', 'depth_coefficient',
2524
'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
2625

27-
2826
# Parameters for an individual model block
2927
BlockArgs = collections.namedtuple('BlockArgs', [
3028
'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
3129
'expand_ratio', 'id_skip', 'stride', 'se_ratio'])
3230

33-
3431
# Change namedtuple defaults
3532
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
3633
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
3734

3835

39-
def relu_fn(x):
40-
""" Swish activation function """
41-
return x * torch.sigmoid(x)
36+
class SwishImplementation(torch.autograd.Function):
37+
@staticmethod
38+
def forward(ctx, i):
39+
result = i * torch.sigmoid(i)
40+
ctx.save_for_backward(i)
41+
return result
42+
43+
@staticmethod
44+
def backward(ctx, grad_output):
45+
i = ctx.saved_variables[0]
46+
sigmoid_i = torch.sigmoid(i)
47+
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
48+
49+
50+
class Swish(nn.Module):
51+
@staticmethod
52+
def forward(x):
53+
return SwishImplementation.apply(x)
54+
55+
56+
relu_fn = Swish()
4257

4358

4459
def round_filters(filters, global_params):
@@ -84,11 +99,13 @@ def get_same_padding_conv2d(image_size=None):
8499
else:
85100
return partial(Conv2dStaticSamePadding, image_size=image_size)
86101

102+
87103
class Conv2dDynamicSamePadding(nn.Conv2d):
88104
""" 2D Convolutions like TensorFlow, for a dynamic image size """
105+
89106
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
90107
super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
91-
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]]*2
108+
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
92109

93110
def forward(self, x):
94111
ih, iw = x.size()[-2:]
@@ -98,12 +115,13 @@ def forward(self, x):
98115
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
99116
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
100117
if pad_h > 0 or pad_w > 0:
101-
x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
118+
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
102119
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
103120

104121

105122
class Conv2dStaticSamePadding(nn.Conv2d):
106123
""" 2D Convolutions like TensorFlow, for a fixed image size"""
124+
107125
def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
108126
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
109127
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
@@ -128,7 +146,7 @@ def forward(self, x):
128146

129147

130148
class Identity(nn.Module):
131-
def __init__(self,):
149+
def __init__(self, ):
132150
super(Identity, self).__init__()
133151

134152
def forward(self, input):
@@ -286,6 +304,7 @@ def get_model_params(model_name, override_params):
286304
'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth',
287305
}
288306

307+
289308
def load_pretrained_weights(model, model_name, load_fc=True):
290309
""" Loads pretrained weights, and downloads if loading for the first time. """
291310
state_dict = model_zoo.load_url(url_map[model_name])

0 commit comments

Comments
 (0)