@@ -2,7 +2,7 @@ export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu
logsigmoid, logcosh, mish, tanhshrink, softshrink, thresholdrelu, trelu, lisht

## Activation functions
-#
+#
# Some of the activation functions have wrapper functions for the GPU in CuArrays.jl.
# https://github.com/JuliaGPU/CuArrays.jl/issues/614
@@ -12,15 +12,11 @@ export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu
Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
function.
"""
-σ(x::Real) = one(x) / (one(x) + exp(-x))
-const sigmoid = σ
-
-# ForwardDiff numerical stability hack
-σ_stable(x::Real) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))
-σ(x::Float32) = σ_stable(x)
-@init @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
-    σ(x::ForwardDiff.Dual{T,Float32}) where T = σ_stable(x)
+function σ(x::Real)
+    t = exp(-abs(x))
+    ifelse(x ≥ 0, inv(one(t) + t), t / (one(t) + t))
end
+const sigmoid = σ

"""
hardσ(x, a=0.2) = max(0, min(1.0, a * x + 0.5))
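
As context for the hunk above: the rewritten σ computes the same sigmoid, but only ever exponentiates a non-positive argument, which is why the Float32 and ForwardDiff special cases can be dropped. A minimal standalone sketch of the difference (not part of the diff; the names σ_naive and σ_rewritten are hypothetical, for comparison only):

# Hypothetical, self-contained comparison of the two formulations.
σ_naive(x::Real) = one(x) / (one(x) + exp(-x))

function σ_rewritten(x::Real)
    t = exp(-abs(x))   # t ∈ (0, 1] for finite x, so it never overflows
    ifelse(x ≥ 0, inv(one(t) + t), t / (one(t) + t))
end

σ_naive(-100f0)       # exp(100f0) overflows to Inf32, so the result collapses to 0.0f0
σ_rewritten(-100f0)   # tiny but finite, and safe for automatic differentiation
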
@@ -159,17 +155,17 @@ function selu(x::Real)
end

"""
-celu(x, α=1) =
+celu(x, α=1) =
(x ≥ 0 ? x : α * (exp(x/α) - 1))

Continuously Differentiable Exponential Linear Units
See [Continuously Differentiable Exponential Linear Units](https://arxiv.org/pdf/1704.07483.pdf).
"""
-celu(x::Real, α::Real = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x)))
+celu(x::Real, α::Real = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x)))


"""
-trelu(x, theta = 1.0) = x > theta ? x : 0
+trelu(x, theta = 1.0) = x > theta ? x : 0

Threshold Gated Rectified Linear.
See [ThresholdRelu](https://arxiv.org/pdf/1402.3337.pdf)
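
The two signature lines in this hunk appear to change only trailing whitespace; as a quick standalone illustration of the piecewise definitions the docstrings quote (the `_sketch` names are hypothetical, not part of the package):

# Hypothetical one-liners restating the docstring formulas above.
celu_sketch(x, α=1) = x ≥ 0 ? x : α * (exp(x / α) - 1)
trelu_sketch(x, θ=1.0) = x > θ ? x : zero(x)

celu_sketch(2.0)     # 2.0: identity on the non-negative side
celu_sketch(-1.0)    # exp(-1) - 1 ≈ -0.63: smooth exponential tail
trelu_sketch(2.0)    # 2.0: values above the threshold pass through
trelu_sketch(0.5)    # 0.0: values at or below the threshold are gated to zero
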
@@ -218,15 +214,15 @@ See [Tanhshrink Activation Function](https://www.gabormelli.com/RKB/Tanhshrink_A
tanhshrink(x::Real) = x - tanh(x)

"""
-softshrink(x, λ=0.5) =
+softshrink(x, λ=0.5) =
(x ≥ λ ? x - λ : (-λ ≥ x ? x + λ : 0))

See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_Activation_Function).
"""
softshrink(x::Real, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ), x + λ)

# Provide an informative error message if activation functions are called with an array
-for f in (:σ, :σ_stable, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
+for f in (:σ, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
    @eval $(f)(x::AbstractArray, args...) =
        error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
end
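
The loop in the final hunk replaces array methods of the scalar activations with the informative error defined just above, steering users toward broadcasting. A hedged usage sketch, assuming this file is part of NNlib.jl and loaded via `using NNlib` (the names used here all appear in the export list at the top of the diff):

using NNlib   # assumed: the package exporting σ, relu, and the other activations

x = randn(Float32, 3, 4)

y = σ.(x)      # elementwise sigmoid via broadcasting
z = relu.(x)   # elementwise relu

# σ(x)  would now throw: Use broadcasting (`σ.(x)`) to apply activation functions to arrays.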