
Commit 60ac742

Added some Activation functions (#175)
* Added Activation Functions
* Update activation.jl
* Remove Typecast errors
* Added Activation function tests
* Removed extra space
* Removed unwanted typecast
* Update Test Functions
* Apply suggestions from code review (Co-Authored-By: matsueushi <[email protected]>)
* Update activation.jl
* Updated celu and trelu
* Update activation.jl

Co-authored-by: matsueushi <[email protected]>
1 parent fbea0c5 commit 60ac742

2 files changed (+90, -11 lines)


src/activation.jl

Lines changed: 44 additions & 7 deletions
@@ -1,5 +1,5 @@
-export σ, sigmoid, relu, leakyrelu, relu6, rrelu, elu, gelu, swish, selu, celu, softplus, softsign, logσ,
-       logsigmoid, logcosh, mish, tanhshrink, softshrink
+export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu, elu, gelu, swish, selu, celu, softplus, softsign, logσ,
+       logsigmoid, logcosh, mish, tanhshrink, softshrink, thresholdrelu, trelu, lisht
 
 """
     σ(x) = 1 / (1 + exp(-x))
@@ -17,6 +17,15 @@ const sigmoid = σ
     σ(x::ForwardDiff.Dual{T,Float32}) where T = σ_stable(x)
 end
 
+"""
+    hardσ(x, a=0.2) = max(0, min(1.0, a * x + 0.5))
+
+Segment-wise linear approximation of sigmoid.
+See: [BinaryConnect: Training Deep Neural Networks with binary weights during propagations](https://arxiv.org/pdf/1511.00363.pdf)
+"""
+hardσ(x::Real, a=0.2) = oftype(x/1, max(zero(x/1), min(one(x/1), oftype(x/1, a) * x + oftype(x/1, 0.5))))
+const hardsigmoid = hardσ
+
 
 """
     logσ(x)
@@ -35,6 +44,15 @@ logσ(x::Real) = -softplus(-x)
 const logsigmoid = logσ
 
 
+"""
+    hardtanh(x) = max(-1, min(1, x))
+
+Segment-wise linear approximation of tanh; cheaper and more computationally efficient than tanh.
+See: (http://ronan.collobert.org/pub/matos/2004_phdthesis_lip6.pdf)
+"""
+hardtanh(x::Real) = max(-one(x), min(one(x), x))
+
+
 """
     relu(x) = max(0, x)
 
@@ -110,6 +128,16 @@ See [Swish: a Self-Gated Activation Function](https://arxiv.org/pdf/1710.05941.p
 """
 swish(x::Real) = x * σ(x)
 
+
+"""
+    lisht(x) = x * tanh(x)
+
+Non-Parametric Linearly Scaled Hyperbolic Tangent activation function.
+See [LiSHT](https://arxiv.org/abs/1901.05894)
+"""
+lisht(x::Real) = x * tanh(x)
+
+
 """
     selu(x) = λ * (x ≥ 0 ? x : α * (exp(x) - 1))
 
@@ -132,9 +160,18 @@ end
 Continuously Differentiable Exponential Linear Units
 See [Continuously Differentiable Exponential Linear Units](https://arxiv.org/pdf/1704.07483.pdf).
 """
-function celu(x::Real, α::Real = one(x))
-    return ifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x)))
-end
+celu(x::Real, α::Real = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x)))
+
+
+"""
+    trelu(x, theta = 1.0) = x > theta ? x : 0
+
+Threshold Gated Rectified Linear Unit.
+See [ThresholdRelu](https://arxiv.org/pdf/1402.3337.pdf)
+"""
+trelu(x::Real, theta = one(x)) = ifelse(x > theta, x, zero(x))
+const thresholdrelu = trelu
+
 
 """
     softsign(x) = x / (1 + |x|)
@@ -184,7 +221,7 @@ See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_A
 softshrink(x::Real, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ), x + λ)
 
 # Provide an informative error message if activation functions are called with an array
-for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :selu, :celu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
+for f in (:σ, :σ_stable, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
   @eval $(f)(x::AbstractArray, args...) =
   error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
-end
+end
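
For a quick feel for the new exports, the sketch below exercises them from the REPL. It assumes a Julia session with a build of NNlib that contains this commit; the values in the comments follow from the scalar definitions in the diff above and are approximate where floating-point rounding applies.

using NNlib   # assumes a version of NNlib that includes this commit

xs = [-2.0, -0.5, 0.0, 0.5, 2.0]

hardσ.(xs)      # ≈ [0.1, 0.4, 0.5, 0.6, 0.9]: the line 0.2x + 0.5, clipped to [0, 1]
hardtanh.(xs)   # [-1.0, -0.5, 0.0, 0.5, 1.0]: x clipped to [-1, 1]
lisht.(xs)      # xs .* tanh.(xs), always non-negative
trelu.(xs)      # [0.0, 0.0, 0.0, 0.0, 2.0]: passes x through only when x > theta (default 1)

# hardsigmoid and thresholdrelu are aliases of hardσ and trelu:
hardsigmoid === hardσ        # true
thresholdrelu === trelu      # true

# Like the existing activations, the new ones refuse whole arrays and point to broadcasting:
# lisht(xs)  # ERROR: Use broadcasting (`lisht.(x)`) to apply activation functions to arrays.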

test/activation.jl

Lines changed: 46 additions & 4 deletions
@@ -1,6 +1,6 @@
 using NNlib, Test, Zygote
 
-ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, relu6, rrelu, elu, gelu, celu, swish, selu, softplus, softsign, logcosh, mish, tanhshrink, softshrink];
+ACTIVATION_FUNCTIONS = [σ, hardσ, hardtanh, relu, leakyrelu, relu6, rrelu, elu, gelu, celu, swish, lisht, selu, trelu, softplus, softsign, logcosh, mish, tanhshrink, softshrink];
 
 function test_value_float_precision_preserving(a)
     @testset "$(a): " begin
@@ -37,53 +37,65 @@ end
 
 @testset "Activation Functions" begin
     @test σ(0.0) == 0.5
+    @test hardσ(0.0) == 0.5
+    @test hardtanh(0.0) == 0.0
     @test relu(0.0) == 0.0
     @test leakyrelu(0.0) == 0.0
     @test relu6(0.0) == 0.0
     @test rrelu(0.0) == 0.0
     @test elu(0.0) == 0.0
     @test gelu(0.0) == 0.0
     @test swish(0.0) == 0.0
+    @test lisht(0.0) == 0.0
     @test softplus(0.0) ≈ log(2.0)
     @test softplus(1e8) ≈ 1e8
     @test softplus(-1e8) ≈ 0.0
     @test softsign(0.0) == 0.0
     @test selu(0.0) == 0.0
     @test celu(0.0) == 0.0
+    @test trelu(0.0) == 0.0
     @test logcosh(0.0) == log(cosh(0.0))
     @test mish(0.0) == 0.0
     @test tanhshrink(0.0) == 0.0
     @test softshrink(0.0) == 0.0
 
     @test σ(1.0) == 1.0 / (1.0 + exp(-1.0))
+    @test hardσ(1.0) == max(0, min(1, 0.2*1.0 + 0.5))
+    @test hardtanh(1.0) == 1.0
     @test relu(1.0) == 1.0
     @test leakyrelu(1.0) == 1.0
     @test relu6(1.0) == 1.0
     @test rrelu(1.0) == 1.0
     @test elu(1.0) == 1.0
     @test gelu(1.0) == 0.8411919906082768
     @test swish(1.0) == 1.0 / (1.0 + exp(-1.0))
+    @test lisht(1.0) ≈ 1.0 * tanh(1.0)
     @test softplus(1.0) ≈ log(exp(1.0) + 1.0)
     @test softsign(1.0) == 0.5
     @test selu(1.0) == 1.0507009873554804934193349852946
     @test celu(1.0) == 1.0
+    @test trelu(1.0) == 0.0
     @test logcosh(1.0) ≈ log(cosh(1.0))
     @test mish(1.0) ≈ tanh(log(1.0 + exp(1.0)))
     @test tanhshrink(1.0) ≈ 0.23840584404423515
     @test softshrink(1.0) == 0.5
 
     @test σ(-1.0) == 1.0 / (1.0 + exp(1.0))
+    @test hardσ(-1.0) == max(0, min(1, 0.2*-1.0 + 0.5))
+    @test hardtanh(-1.0) == -1.0
     @test relu(-1.0) == 0.0
     @test leakyrelu(-1.0) == -0.01
     @test relu6(-1.0) == 0.0
     @test -1/3.0 <= rrelu(-1.0) <= -1/8.0
     @test elu(-1.0) == exp(-1.0) - 1.0
     @test gelu(-1.0) == -0.15880800939172324
     @test swish(-1.0) == -1.0 / (1.0 + exp(1.0))
+    @test lisht(-1.0) ≈ -1.0 * tanh(-1.0)
     @test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)
     @test softsign(-1.0) == -0.5
     @test selu(-1.0) == 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0)
     @test celu(-1.0) == exp(-1.0) - 1
+    @test trelu(-1.0) == 0.0
     @test log(cosh(-1.0)) ≈ log(cosh(-1.0))
     @test mish(-1.0) ≈ -tanh(log(1.0 + exp(-1.0)))
     @test tanhshrink(-1.0) ≈ -0.23840584404423515
@@ -101,7 +113,7 @@ end
 end
 
 @testset "Test Integer64 and Integer32 inputs will force Float64 outputs" begin
-    test_value_int_input_forces_float64.(filter(x -> (x != relu && x != relu6), ACTIVATION_FUNCTIONS))
+    test_value_int_input_forces_float64.(filter(x -> (x != relu && x != relu6 && x != hardtanh && x != trelu), ACTIVATION_FUNCTIONS))
 
     @testset "relu: " begin
         # relu doesn't have to force floating point outputs
@@ -114,7 +126,18 @@ end
         @test typeof(relu6(Int64(1))) == Int64
         @test typeof(relu6(Int32(1))) == Int32
     end
-
+
+    @testset "hardtanh: " begin
+        # hardtanh doesn't have to force floating point outputs
+        @test typeof(hardtanh(Int64(1))) == Int64
+        @test typeof(hardtanh(Int32(1))) == Int32
+    end
+
+    @testset "trelu: " begin
+        # trelu doesn't have to force floating point outputs
+        @test typeof(trelu(Int64(1))) == Int64
+        @test typeof(trelu(Int32(1))) == Int32
+    end
 end
 
 @testset "Float gradient inference" begin
@@ -202,4 +225,23 @@ end
     end
 
     @test logcosh(1_000.0) + log(2) == 1_000.0
-end
+
+    @testset "hardsigmoid" begin
+        @test hardsigmoid(0.3) == 0.56
+        @test hardsigmoid(-0.3) == 0.44
+        @test hardsigmoid(0.1, 0.5) == 0.55
+        for T in [:Float32, :Float64]
+            @eval @test hardsigmoid.($T[-100_000, 100_000.]) ≈ $T[0., 1.]
+        end
+    end
+
+    @test hardtanh(10.0) == 1.0
+    @test lisht(2.5) == 2.5*tanh(2.5)
+
+    @testset "trelu" begin
+        @test trelu(0.5) == 0.0
+        @test trelu(1.0) == 0.0
+        @test trelu(1.1) == 1.1
+        @test trelu(0.9, 0.5) == 0.9
+    end
+end
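
The new x != hardtanh && x != trelu clauses in the integer-input filter mirror the existing exclusions for relu and relu6: those functions are built only from comparisons, min/max, ifelse, zero and one, so they return integers for integer inputs instead of promoting to Float64, which is exactly what the two added typeof testsets assert. A minimal sketch of the distinction, again assuming an NNlib build that includes this commit:

using NNlib

typeof(hardtanh(Int32(1)))   # Int32:   max/min/one preserve the input's integer type
typeof(trelu(Int64(2)))      # Int64:   ifelse/zero preserve the input's integer type
typeof(hardσ(1))             # Float64: oftype(x/1, ...) promotes integers via x/1
typeof(lisht(1))             # Float64: tanh of an Int returns a Float64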
