Skip to content

Commit 7ac6db4

Browse files
committed
Missed activations.py
1 parent 506df0e commit 7ac6db4

File tree

1 file changed

+180
-0
lines changed

1 file changed

+180
-0
lines changed

timm/models/activations.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import torch
2+
from torch import nn as nn
3+
from torch.nn import functional as F
4+
5+
6+
_USE_MEM_EFFICIENT_ISH = True
7+
if _USE_MEM_EFFICIENT_ISH:
8+
# This version reduces memory overhead of Swish during training by
9+
# recomputing torch.sigmoid(x) in backward instead of saving it.
10+
class SwishAutoFn(torch.autograd.Function):
11+
"""Swish - Described in: https://arxiv.org/abs/1710.05941
12+
Memory efficient variant from:
13+
https://medium.com/the-artificial-impostor/more-memory-efficient-swish-activation-function-e07c22c12a76
14+
"""
15+
@staticmethod
16+
def forward(ctx, x):
17+
result = x.mul(torch.sigmoid(x))
18+
ctx.save_for_backward(x)
19+
return result
20+
21+
@staticmethod
22+
def backward(ctx, grad_output):
23+
x = ctx.saved_variables[0]
24+
sigmoid_x = torch.sigmoid(x)
25+
return grad_output.mul(sigmoid_x * (1 + x * (1 - sigmoid_x)))
26+
27+
def swish(x, inplace=False):
28+
# inplace ignored
29+
return SwishAutoFn.apply(x)
30+
31+
32+
class MishAutoFn(torch.autograd.Function):
33+
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
34+
Experimental memory-efficient variant
35+
"""
36+
37+
@staticmethod
38+
def forward(ctx, x):
39+
ctx.save_for_backward(x)
40+
y = x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x)))
41+
return y
42+
43+
@staticmethod
44+
def backward(ctx, grad_output):
45+
x = ctx.saved_variables[0]
46+
x_sigmoid = torch.sigmoid(x)
47+
x_tanh_sp = F.softplus(x).tanh()
48+
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
49+
50+
def mish(x, inplace=False):
51+
# inplace ignored
52+
return MishAutoFn.apply(x)
53+
54+
55+
class WishAutoFn(torch.autograd.Function):
56+
"""Wish: My own mistaken creation while fiddling with Mish. Did well in some experiments.
57+
Experimental memory-efficient variant
58+
"""
59+
60+
@staticmethod
61+
def forward(ctx, x):
62+
ctx.save_for_backward(x)
63+
y = x.mul(torch.tanh(torch.exp(x)))
64+
return y
65+
66+
@staticmethod
67+
def backward(ctx, grad_output):
68+
x = ctx.saved_variables[0]
69+
x_exp = x.exp()
70+
x_tanh_exp = x_exp.tanh()
71+
return grad_output.mul(x_tanh_exp + x * x_exp * (1 - x_tanh_exp * x_tanh_exp))
72+
73+
def wish(x, inplace=False):
74+
# inplace ignored
75+
return WishAutoFn.apply(x)
76+
else:
77+
def swish(x, inplace=False):
78+
"""Swish - Described in: https://arxiv.org/abs/1710.05941
79+
"""
80+
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
81+
82+
83+
def mish(x, inplace=False):
84+
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
85+
"""
86+
inner = F.softplus(x).tanh()
87+
return x.mul_(inner) if inplace else x.mul(inner)
88+
89+
90+
def wish(x, inplace=False):
91+
"""Wish: My own mistaken creation while fiddling with Mish. Did well in some experiments.
92+
"""
93+
inner = x.exp().tanh()
94+
return x.mul_(inner) if inplace else x.mul(inner)
95+
96+
97+
class Swish(nn.Module):
98+
def __init__(self, inplace=False):
99+
super(Swish, self).__init__()
100+
self.inplace = inplace
101+
102+
def forward(self, x):
103+
return swish(x, self.inplace)
104+
105+
106+
class Mish(nn.Module):
107+
def __init__(self, inplace=False):
108+
super(Mish, self).__init__()
109+
self.inplace = inplace
110+
111+
def forward(self, x):
112+
return mish(x, self.inplace)
113+
114+
115+
class Wish(nn.Module):
116+
def __init__(self, inplace=False):
117+
super(Wish, self).__init__()
118+
self.inplace = inplace
119+
120+
def forward(self, x):
121+
return wish(x, self.inplace)
122+
123+
124+
def sigmoid(x, inplace=False):
125+
return x.sigmoid_() if inplace else x.sigmoid()
126+
127+
128+
# PyTorch has this, but not with a consistent inplace argmument interface
129+
class Sigmoid(nn.Module):
130+
def __init__(self, inplace=False):
131+
super(Sigmoid, self).__init__()
132+
self.inplace = inplace
133+
134+
def forward(self, x):
135+
return x.sigmoid_() if self.inplace else x.sigmoid()
136+
137+
138+
def tanh(x, inplace=False):
139+
return x.tanh_() if inplace else x.tanh()
140+
141+
142+
# PyTorch has this, but not with a consistent inplace argmument interface
143+
class Tanh(nn.Module):
144+
def __init__(self, inplace=False):
145+
super(Tanh, self).__init__()
146+
self.inplace = inplace
147+
148+
def forward(self, x):
149+
return x.tanh_() if self.inplace else x.tanh()
150+
151+
152+
def hard_swish(x, inplace=False):
153+
inner = F.relu6(x + 3.).div_(6.)
154+
return x.mul_(inner) if inplace else x.mul(inner)
155+
156+
157+
class HardSwish(nn.Module):
158+
def __init__(self, inplace=False):
159+
super(HardSwish, self).__init__()
160+
self.inplace = inplace
161+
162+
def forward(self, x):
163+
return hard_swish(x, self.inplace)
164+
165+
166+
def hard_sigmoid(x, inplace=False):
167+
if inplace:
168+
return x.add_(3.).clamp_(0., 6.).div_(6.)
169+
else:
170+
return F.relu6(x + 3.) / 6.
171+
172+
173+
class HardSigmoid(nn.Module):
174+
def __init__(self, inplace=False):
175+
super(HardSigmoid, self).__init__()
176+
self.inplace = inplace
177+
178+
def forward(self, x):
179+
return hard_sigmoid(x, self.inplace)
180+

0 commit comments

Comments
 (0)