
Commit 07852a8

Update test cases of activation functions (#162)
* Add mish to ACTIVATION_FUNCTIONS
* Type stability test for gradients
* Resolve gradient error
* Removing spaces
1 parent 8ae46c5 commit 07852a8

3 files changed: +28 -10 lines changed


Project.toml

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@ julia = "1"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test"]
+test = ["Test", "Zygote"]

src/activation.jl

Lines changed: 9 additions & 7 deletions
@@ -72,7 +72,8 @@ elu(x, α = one(x)) = ifelse(x ≥ 0, x/one(x), α * (exp(x) - one(x)))
 activation function.
 """
 function gelu(x::Real)
-    λ = oftype(x/1, √(2/π))
+    p = oftype(x/1, π)
+    λ = oftype(x/1, √(2/p))
     α = oftype(x/1, 0.044715)
     h = oftype(x/1, 0.5)
     h * x * (one(x) + tanh(λ * (x + α * x^3)))
@@ -126,12 +127,6 @@ Return `log(cosh(x))` which is computed in a numerically stable way.
 """
 logcosh(x::T) where T = x + softplus(-2x) - log(convert(T, 2))
 
-# Provide an informative error message if activation functions are called with an array
-for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu, :softsign, :softplus, :logcosh)
-  @eval $(f)(x::AbstractArray, args...) =
-    error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
-end
-
 
 """
     mish(x) = x * tanh(softplus(x))
@@ -140,3 +135,10 @@ Self Regularized Non-Monotonic Neural Activation Function
 See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).
 """
 mish(x::Real) = x * tanh(softplus(x))
+
+
+# Provide an informative error message if activation functions are called with an array
+for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu, :softsign, :softplus, :logcosh, :mish)
+  @eval $(f)(x::AbstractArray, args...) =
+    error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
+end
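
The relocated @eval loop now also covers mish, so each activation listed in the tuple gets an AbstractArray method whose only purpose is to point the caller at broadcasting. A small usage sketch, assuming NNlib is loaded:

using NNlib

x = randn(Float32, 4)

y = mish.(x)   # apply the activation elementwise via broadcasting
# mish(x)      # would throw the informative error generated by the @eval
               # loop above rather than a generic MethodError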

test/activation.jl

Lines changed: 17 additions & 2 deletions
@@ -1,6 +1,6 @@
-using NNlib, Test
+using NNlib, Test, Zygote
 
-ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logcosh];
+ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logcosh, mish];
 
 function test_value_float_precision_preserving(a)
   @testset "$(a): " begin
@@ -24,6 +24,17 @@ function test_value_int_input_forces_float64(a)
   end
 end
 
+function test_gradient_float_precision_preserving(a)
+  @testset "$(a): " begin
+    for T in [Float32, Float64]
+      for val in [-10, -1, 0, 1, 10]
+        val = @inferred a'(T(val))
+        @test typeof(val) == T
+      end
+    end
+  end
+end
+
 @testset "Activation Functions" begin
   @test σ(0.0) == 0.5
   @test relu(0.0) == 0.0
@@ -83,6 +94,10 @@ end
     @test typeof(relu(Int32(1))) == Int32
   end
 end
+
+@testset "Float gradient inference" begin
+  test_gradient_float_precision_preserving.(ACTIVATION_FUNCTIONS)
+end
 
 @testset "softmax" begin
   xs = rand(5,5)
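
The new helper relies on two pieces: Zygote's adjoint syntax a'(x), which evaluates the derivative of a scalar function a at x, and Test's @inferred, which fails if the return type of that call cannot be inferred statically. A minimal sketch of the property being checked, run by hand for a single activation:

using NNlib, Test, Zygote

# gelu'(x) is shorthand for gradient(gelu, x)[1]; the check asserts that the
# gradient keeps the input's floating-point type and is type-stable.
g = @inferred gelu'(Float32(1))
@test typeof(g) == Float32

# The added test set broadcasts the same helper over every activation:
# test_gradient_float_precision_preserving.(ACTIVATION_FUNCTIONS)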
