Skip to content

Commit 4b4c3a3

Browse files
committed
implment gelu
test for gelu fix indent
1 parent 085adb7 commit 4b4c3a3

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

src/NNlib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module NNlib
22

33
using Requires, Libdl
44

5-
export σ, sigmoid, relu, leakyrelu, elu, swish, selu, softplus, softsign, logσ, logsigmoid,
5+
export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid,
66
softmax, logsoftmax, maxpool, meanpool
77

88
include("numeric.jl")

src/activation.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,20 @@ You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
6666
"""
6767
elu(x, α = one(x)) = ifelse(x 0, x/1, α * (exp(x) - one(x)))
6868

69+
"""
70+
gelu(x) = 0.5x*(1 + tanh(√(2/π)*(x + 0.044715x^3)))
71+
72+
[Gaussian Error Linear Unit](https://arxiv.org/pdf/1606.08415.pdf)
73+
activation function.
74+
"""
75+
function gelu(x)
76+
λ = oftype(x/1, (2/π))
77+
α = oftype(x/1, 0.044715)
78+
h = oftype(x/1, 0.5)
79+
h * x * (one(x) + tanh* (x + α * x^3)))
80+
end
81+
82+
6983
"""
7084
swish(x) = x * σ(x)
7185

test/activation.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, swish, selu, softplus, softsign];
1+
ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign];
22

33
function test_value_float_precision_preserving(a)
44
@testset "$(a): " begin
@@ -42,6 +42,7 @@ end
4242
@test relu(0.0) == 0.0
4343
@test leakyrelu(0.0) == 0.0
4444
@test elu(0.0) == 0.0
45+
@test gelu(0.0) == 0.0
4546
@test swish(0.0) == 0.0
4647
@test softplus(0.0) log(2.0)
4748
@test softsign(0.0) == 0.0
@@ -51,6 +52,7 @@ end
5152
@test relu(1.0) == 1.0
5253
@test leakyrelu(1.0) == 1.0
5354
@test elu(1.0) == 1.0
55+
@test gelu(1.0) == 0.8411919906082768
5456
@test swish(1.0) == 1.0 / (1.0 + exp(-1.0))
5557
@test softplus(1.0) log(exp(1.0) + 1.0)
5658
@test softsign(1.0) == 0.5
@@ -60,6 +62,7 @@ end
6062
@test relu(-1.0) == 0.0
6163
@test leakyrelu(-1.0) == -0.01
6264
@test elu(-1.0) == exp(-1.0) - 1.0
65+
@test gelu(-1.0) == -0.15880800939172324
6366
@test swish(-1.0) == -1.0 / (1.0 + exp(1.0))
6467
@test softplus(-1.0) log(exp(-1.0) + 1.0)
6568
@test softsign(-1.0) == -0.5

0 commit comments

Comments
 (0)