
Commit c8b7661

Merge pull request #198 from cossio/sigmoid
fix sigmoid
2 parents c835047 + f003632 commit c8b7661

2 files changed: +13 additions, -17 deletions


src/activation.jl

Lines changed: 10 additions & 14 deletions
@@ -2,7 +2,7 @@ export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu
        logsigmoid, logcosh, mish, tanhshrink, softshrink, thresholdrelu, trelu, lisht

 ## Activation functions
-#
+#
 # Some of activation functions have its wrapper function for GPU in CuArrays.jl.
 # https://github.com/JuliaGPU/CuArrays.jl/issues/614

@@ -12,15 +12,11 @@ export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu
 Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
 function.
 """
-σ(x::Real) = one(x) / (one(x) + exp(-x))
-const sigmoid = σ
-
-# ForwardDiff numerical stability hack
-σ_stable(x::Real) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))
-σ(x::Float32) = σ_stable(x)
-@init @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
-    σ(x::ForwardDiff.Dual{T,Float32}) where T = σ_stable(x)
+function σ(x::Real)
+    t = exp(-abs(x))
+    ifelse(x ≥ 0, inv(one(t) + t), t / (one(t) + t))
 end
+const sigmoid = σ

 """
     hardσ(x, a=0.2) = max(0, min(1.0, a * x + 0.5))
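
Why the rewrite helps, in brief: the old form evaluates exp(-x), which overflows for large negative inputs (the reason for the Float32 / ForwardDiff σ_stable hack being deleted here), while the new form only ever calls exp on a non-positive argument, which can at worst underflow to zero. A minimal standalone sketch of the two formulations; the names naive_sigmoid and stable_sigmoid are local illustrations, not part of this commit:

    naive_sigmoid(x::Real) = one(x) / (one(x) + exp(-x))   # old definition: exp(-x) can overflow to Inf

    function stable_sigmoid(x::Real)                        # new definition from this commit
        t = exp(-abs(x))                                    # argument of exp is ≤ 0, so t ∈ (0, 1]: no overflow
        ifelse(x ≥ 0, inv(one(t) + t), t / (one(t) + t))
    end

    stable_sigmoid(0.0)                          # 0.5
    stable_sigmoid(1f3), stable_sigmoid(-1f3)    # (1.0f0, 0.0f0), without ever forming exp(1000f0) = Inf32
    stable_sigmoid(1.0) ≈ naive_sigmoid(1.0)     # true: both agree in the well-scaled range
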
@@ -159,17 +155,17 @@ function selu(x::Real)
 end

 """
-    celu(x, α=1) =
+    celu(x, α=1) =
         (x ≥ 0 ? x : α * (exp(x/α) - 1))

 Continuously Differentiable Exponential Linear Units
 See [Continuously Differentiable Exponential Linear Units](https://arxiv.org/pdf/1704.07483.pdf).
 """
-celu(x::Real, α::Real = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x)))
+celu(x::Real, α::Real = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x)))


 """
-    trelu(x, theta = 1.0) = x > theta ? x : 0
+    trelu(x, theta = 1.0) = x > theta ? x : 0

 Threshold Gated Rectified Linear.
 See [ThresholdRelu](https://arxiv.org/pdf/1402.3337.pdf)
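
The celu and trelu lines in this hunk appear to change only whitespace (the visible content of the removed and added lines is identical). For readers skimming the diff, here is a tiny standalone illustration built from the docstring formulas; the _demo names are not in the package:

    # celu(x, α): identity for x ≥ 0, α*(exp(x/α) - 1) below zero (tends to -α as x → -∞)
    celu_demo(x, α = 1.0) = x ≥ 0 ? x : α * (exp(x/α) - 1)
    celu_demo(2.0)        # 2.0
    celu_demo(-1.0)       # exp(-1.0) - 1 ≈ -0.632

    # trelu(x, θ): pass x through only above the threshold θ, otherwise 0
    trelu_demo(x, θ = 1.0) = x > θ ? x : zero(x)
    trelu_demo(1.5)       # 1.5
    trelu_demo(0.5)       # 0.0
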
@@ -218,15 +214,15 @@ See [Tanhshrink Activation Function](https://www.gabormelli.com/RKB/Tanhshrink_A
 tanhshrink(x::Real) = x - tanh(x)

 """
-    softshrink(x, λ=0.5) =
+    softshrink(x, λ=0.5) =
         (x ≥ λ ? x - λ : (-λ ≥ x ? x + λ : 0))

 See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_Activation_Function).
 """
 softshrink(x::Real, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ), x + λ)

 # Provide an informative error message if activation functions are called with an array
-for f in (:σ, :σ_stable, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
+for f in (:σ, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
     @eval $(f)(x::AbstractArray, args...) =
         error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
 end
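
As an aside, the min/max form of softshrink in this hunk is equivalent to the piecewise definition given in its docstring. A quick standalone check using local names that are not part of the diff:

    softshrink_minmax(x, λ = 0.5)    = min(max(zero(x), x - λ), x + λ)                 # form used in the source
    softshrink_piecewise(x, λ = 0.5) = x ≥ λ ? x - λ : (-λ ≥ x ? x + λ : zero(x))      # docstring form

    all(softshrink_minmax(x) == softshrink_piecewise(x) for x in -2.0:0.25:2.0)        # true

The loop at the end of the hunk, now without :σ_stable, defines an AbstractArray method for each listed activation that throws an error directing users to broadcasting, e.g. σ.(x).
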

test/activation.jl

Lines changed: 3 additions & 3 deletions
@@ -68,7 +68,7 @@ end
 @test rrelu(1.0) == 1.0
 @test elu(1.0) == 1.0
 @test gelu(1.0) == 0.8411919906082768
-@test swish(1.0) == 1.0 / (1.0 + exp(-1.0))
+@test swish(1.0) == σ(1.0)
 @test lisht(1.0) ≈ 1.0 * tanh(1.0)
 @test softplus(1.0) ≈ log(exp(1.0) + 1.0)
 @test softsign(1.0) == 0.5
@@ -80,7 +80,7 @@ end
 @test tanhshrink(1.0) ≈ 0.23840584404423515
 @test softshrink(1.0) == 0.5

-@test σ(-1.0) == 1.0 / (1.0 + exp(1.0))
+@test σ(-1.0) == exp(-1.0) / (1.0 + exp(-1.0))
 @test hardσ(-1.0) == max(0,min(1,0.2*-1.0 + 0.5))
 @test hardtanh(-1.0) == -1.0
 @test relu(-1.0) == 0.0
@@ -89,7 +89,7 @@ end
 @test -1/3.0 <= rrelu(-1.0) <= -1/8.0
 @test elu(-1.0) == exp(-1.0) - 1.0
 @test gelu(-1.0) == -0.15880800939172324
-@test swish(-1.0) == -1.0 / (1.0 + exp(1.0))
+@test swish(-1.0) == -σ(-1.0)
 @test lisht(-1.0) ≈ -1.0 * tanh(-1.0)
 @test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)
 @test softsign(-1.0) == -0.5
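
Why the expected values changed: these tests compare with exact ==, so the expected expression has to follow the same floating-point path as σ itself. For negative arguments the new σ computes exp(-|x|) / (1 + exp(-|x|)); the old expected value 1.0 / (1.0 + exp(1.0)) is equal to it algebraically (multiply numerator and denominator by exp(-1)) but via different operations, so it need not match bit for bit:

    1 / (1 + exp(1)) = exp(-1) / (exp(-1) * (1 + exp(1))) = exp(-1) / (exp(-1) + 1)

The swish assertions rely on swish(x) = x * σ(x), so swish(1.0) is exactly σ(1.0) and swish(-1.0) is exactly -σ(-1.0), since multiplying by ±1.0 is exact.
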
