Commit c335aa5

Activation functions added (#168)
* some activation functions added
* tests added
* celu corrected
* softshrink corrected
* relu6 test corrected
* 6 in relu6 typecasted
1 parent 3dc371d commit c335aa5

2 files changed: +97 −16 lines

src/activation.jl

Lines changed: 56 additions & 13 deletions
@@ -1,5 +1,5 @@
-export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ,
-       logsigmoid, logcosh, mish
+export σ, sigmoid, relu, leakyrelu, relu6, rrelu, elu, gelu, swish, selu, celu, softplus, softsign, logσ,
+       logsigmoid, logcosh, mish, tanhshrink, softshrink
 
 """
     σ(x) = 1 / (1 + exp(-x))
@@ -13,7 +13,7 @@ const sigmoid = σ
 # ForwardDiff numerical stability hack
 σ_stable(x::Real) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))
 σ(x::Float32) = σ_stable(x)
-@init @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
+@init @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
   σ(x::ForwardDiff.Dual{T,Float32}) where T = σ_stable(x)
 end
 
@@ -51,8 +51,29 @@ Leaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_ne
 activation function.
 You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`.
 """
-leakyrelu(x::Real, a = oftype(x/1, 0.01)) = max(a*x, x/one(x))
+leakyrelu(x::Real, a = oftype(x / 1, 0.01)) = max(a * x, x / one(x))
 
+"""
+    relu6(x) = min(max(0, x), 6)
+
+[Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
+activation function.
+"""
+relu6(x::Real) = min(relu(x), one(x)*oftype(x, 6))
+
+"""
+    rrelu(x) = max(ax, x)
+
+    a = randomly sampled from uniform distribution U(l, u)
+
+Randomized Leaky [Rectified Linear Unit](https://arxiv.org/pdf/1505.00853.pdf)
+activation function.
+You can also specify the bound explicitly, e.g. `rrelu(x, 0.0, 1.0)`.
+"""
+function rrelu(x::Real, l::Real = 1 / 8.0, u::Real = 1 / 3.0)
+    a = oftype(x / 1, (u - l) * rand() + l)
+    return leakyrelu(x, a)
+end
 
 """
     elu(x, α = 1) =
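
For orientation, a quick usage sketch of the two rectifier variants added above (illustrative only, not part of the diff; the values follow from the definitions and from the tests further down):

using NNlib

relu6(10.0)              # 6.0: clamped to the interval [0, 6]
relu6(Int32(3))          # 3, still an Int32; like relu, relu6 does not force floats
rrelu(-1.0)              # random: the slope is drawn from U(1/8, 1/3), so the result
                         # lands somewhere in [-1/3, -1/8] and changes on every call
rrelu(-1.0, 0.25, 0.5)   # custom bounds l = 0.25, u = 0.5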
@@ -62,7 +83,7 @@ Exponential Linear Unit activation function.
 See [Fast and Accurate Deep Network Learning by Exponential Linear Units](https://arxiv.org/abs/1511.07289).
 You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
 """
-elu(x, α = one(x)) = ifelse(x ≥ 0, x/one(x), α * (exp(x) - one(x)))
+elu(x, α = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x) - one(x)))
 
 
 """
@@ -72,10 +93,10 @@ elu(x, α = one(x)) = ifelse(x ≥ 0, x/one(x), α * (exp(x) - one(x)))
 activation function.
 """
 function gelu(x::Real)
-    p = oftype(x/1, π)
-    λ = oftype(x/1, √(2/p))
-    α = oftype(x/1, 0.044715)
-    h = oftype(x/1, 0.5)
+    p = oftype(x / 1, π)
+    λ = oftype(x / 1, √(2 / p))
+    α = oftype(x / 1, 0.044715)
+    h = oftype(x / 1, 0.5)
     h * x * (one(x) + tanh(λ * (x + α * x^3)))
 end
 
@@ -98,11 +119,20 @@ Scaled exponential linear units.
 See [Self-Normalizing Neural Networks](https://arxiv.org/pdf/1706.02515.pdf).
 """
 function selu(x::Real)
-    λ = oftype(x/1, 1.0507009873554804934193349852946)
-    α = oftype(x/1, 1.6732632423543772848170429916717)
-    λ * ifelse(x > 0, x/one(x), α * (exp(x) - one(x)))
+    λ = oftype(x / 1, 1.0507009873554804934193349852946)
+    α = oftype(x / 1, 1.6732632423543772848170429916717)
+    λ * ifelse(x > 0, x / one(x), α * (exp(x) - one(x)))
 end
 
+"""
+    celu(x) = (x ≥ 0 ? x : α * (exp(x/α) - 1))
+
+Continuously Differentiable Exponential Linear Units
+See [Continuously Differentiable Exponential Linear Units](https://arxiv.org/pdf/1704.07483.pdf).
+"""
+function celu(x::Real, α::Real = one(x))
+    return ifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x)))
+end
 
 """
     softsign(x) = x / (1 + |x|)
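
A brief illustration of how `celu` relates to `elu` (illustrative only, not part of the commit): at α = 1 the two agree, while for other α the exponent in `celu` is scaled by 1/α.

using NNlib

celu(-1.0)        # exp(-1.0) - 1, identical to elu(-1.0) since α defaults to 1
celu(-4.0, 0.5)   # 0.5 * (exp(-4.0 / 0.5) - 1); elu would give 0.5 * (exp(-4.0) - 1)
celu(42)          # == 42 (returned as 42.0; integer input is promoted by x / one(x))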
@@ -136,9 +166,22 @@ See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://
 """
 mish(x::Real) = x * tanh(softplus(x))
 
+"""
+    tanhshrink(x) = x - tanh(x)
+
+See [Tanhshrink Activation Function](https://www.gabormelli.com/RKB/Tanhshrink_Activation_Function)
+"""
+tanhshrink(x::Real) = x - tanh(x)
+
+"""
+    softshrink = (x ≥ λ ? x-λ : (-λ ≥ x ? x+λ : 0))
+
+See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_Activation_Function)
+"""
+softshrink(x::Real, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ), x + λ)
 
 # Provide an informative error message if activation functions are called with an array
-for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu, :softsign, :softplus, :logcosh, :mish)
+for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :selu, :celu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
   @eval $(f)(x::AbstractArray, args...) =
     error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
 end
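
A short sketch of the two shrink functions and of the array error path extended above (illustrative, not from the commit):

using NNlib

tanhshrink(1.0)          # ≈ 0.23840584404423515, i.e. 1 - tanh(1)
softshrink(0.7)          # 0.2: inputs above the default λ = 0.5 are shifted toward zero by λ
softshrink(0.3)          # 0.0: anything inside [-λ, λ] collapses to zero
softshrink.([1.0, 2.0])  # [0.5, 1.5] via broadcasting
softshrink([1.0, 2.0])   # throws: "Use broadcasting (`softshrink.(x)`) to apply activation functions to arrays."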

test/activation.jl

Lines changed: 41 additions & 3 deletions
@@ -1,6 +1,6 @@
 using NNlib, Test, Zygote
 
-ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logcosh, mish];
+ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, relu6, rrelu, elu, gelu, celu, swish, selu, softplus, softsign, logcosh, mish, tanhshrink, softshrink];
 
 function test_value_float_precision_preserving(a)
     @testset "$(a): " begin
@@ -39,6 +39,8 @@ end
 @test σ(0.0) == 0.5
 @test relu(0.0) == 0.0
 @test leakyrelu(0.0) == 0.0
+@test relu6(0.0) == 0.0
+@test rrelu(0.0) == 0.0
 @test elu(0.0) == 0.0
 @test gelu(0.0) == 0.0
 @test swish(0.0) == 0.0
@@ -47,32 +49,45 @@ end
 @test softplus(-1e8) ≈ 0.0
 @test softsign(0.0) == 0.0
 @test selu(0.0) == 0.0
+@test celu(0.0) == 0.0
 @test logcosh(0.0) == log(cosh(0.0))
 @test mish(0.0) == 0.0
+@test tanhshrink(0.0) == 0.0
+@test softshrink(0.0) == 0.0
 
 @test σ(1.0) == 1.0 / (1.0 + exp(-1.0))
 @test relu(1.0) == 1.0
 @test leakyrelu(1.0) == 1.0
+@test relu6(1.0) == 1.0
+@test rrelu(1.0) == 1.0
 @test elu(1.0) == 1.0
 @test gelu(1.0) == 0.8411919906082768
 @test swish(1.0) == 1.0 / (1.0 + exp(-1.0))
 @test softplus(1.0) ≈ log(exp(1.0) + 1.0)
 @test softsign(1.0) == 0.5
 @test selu(1.0) == 1.0507009873554804934193349852946
+@test celu(1.0) == 1.0
 @test logcosh(1.0) ≈ log(cosh(1.0))
 @test mish(1.0) ≈ tanh(log(1.0 + exp(1.0)))
+@test tanhshrink(1.0) ≈ 0.23840584404423515
+@test softshrink(1.0) == 0.5
 
 @test σ(-1.0) == 1.0 / (1.0 + exp(1.0))
 @test relu(-1.0) == 0.0
 @test leakyrelu(-1.0) == -0.01
+@test relu6(-1.0) == 0.0
+@test -1/3.0 <= rrelu(-1.0) <= -1/8.0
 @test elu(-1.0) == exp(-1.0) - 1.0
 @test gelu(-1.0) == -0.15880800939172324
 @test swish(-1.0) == -1.0 / (1.0 + exp(1.0))
 @test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)
 @test softsign(-1.0) == -0.5
 @test selu(-1.0) == 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0)
+@test celu(-1.0) == exp(-1.0) - 1
 @test log(cosh(-1.0)) ≈ log(cosh(-1.0))
 @test mish(-1.0) ≈ -tanh(log(1.0 + exp(-1.0)))
+@test tanhshrink(-1.0) ≈ -0.23840584404423515
+@test softshrink(-1.0) == -0.5
 
 @testset "Float inference" begin
     test_value_float_precision_preserving.(ACTIVATION_FUNCTIONS)
@@ -86,13 +101,20 @@ end
 end
 
 @testset "Test Integer64 and Integer32 inputs will force Float64 outputs" begin
-    test_value_int_input_forces_float64.(filter(x -> x != relu, ACTIVATION_FUNCTIONS))
+    test_value_int_input_forces_float64.(filter(x -> (x != relu && x != relu6), ACTIVATION_FUNCTIONS))
 
     @testset "relu: " begin
        # relu doesn't have to force floating point outputs
        @test typeof(relu(Int64(1))) == Int64
        @test typeof(relu(Int32(1))) == Int32
     end
+
+    @testset "relu6: " begin
+       # relu6 doesn't have to force floating point outputs
+       @test typeof(relu6(Int64(1))) == Int64
+       @test typeof(relu6(Int32(1))) == Int32
+    end
+
 end
 
 @testset "Float gradient inference" begin
@@ -155,6 +177,22 @@ end
 @test leakyrelu( 0.4,0.3) ≈ 0.4
 @test leakyrelu(-0.4,0.3) ≈ -0.12
 
+@test relu6(10.0) == 6.0
+@test -0.2 <= rrelu(-0.4,0.25,0.5) <= -0.1
+
+@testset "celu" begin
+    @test celu(42) == 42
+    @test celu(42.) == 42.
+
+    @test celu(-4, 0.5) ≈ 0.5*(exp(-4.0/0.5) - 1)
+end
+
+@testset "softshrink" begin
+    @test softshrink(15., 5.) == 10.
+    @test softshrink(4., 5.) == 0.
+    @test softshrink(-15., 5.) == -10.
+end
+
 @testset "logsigmoid" begin
     xs = randn(10,10)
     @test logsigmoid.(xs) ≈ log.(sigmoid.(xs))
@@ -164,4 +202,4 @@ end
 end
 
 @test logcosh(1_000.0) + log(2) == 1_000.0
-end
+end
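
Note that the `rrelu` assertions above are range checks rather than exact equalities, since `rrelu` draws its slope from `rand()`. A sketch of how one might pin it down locally (an assumption: seeding the global RNG that `rand()` uses; this is not part of the test suite):

using NNlib, Random

Random.seed!(0)
y = rrelu(-1.0)
@assert -1/3 <= y <= -1/8   # the only property the tests actually rely on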
