Commit 71f0127

Merge pull request #108 from pshashk/logcosh

Add `logcosh` activation

2 parents 18c2f77 + 52b9c39

2 files changed: +16 -3 lines

src/activation.jl

Lines changed: 10 additions & 2 deletions

@@ -1,5 +1,5 @@
 export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ,
-       logsigmoid
+       logsigmoid, logcosh
 
 """
     σ(x) = 1 / (1 + exp(-x))
@@ -118,8 +118,16 @@ See [Deep Sparse Rectifier Neural Networks](http://proceedings.mlr.press/v15/glo
 """
 softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))
 
+
+"""
+    logcosh(x)
+
+Return `log(cosh(x))` which is computed in a numerically stable way.
+"""
+logcosh(x::T) where T = x + softplus(-2x) - log(convert(T, 2))
+
 # Provide an informative error message if activation functions are called with an array
-for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu, :softsign, :softplus)
+for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu, :softsign, :softplus, :logcosh)
   @eval $(f)(x::AbstractArray, args...) =
     error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
 end
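
Why the definition above is numerically stable: log(cosh(x)) = log((exp(x) + exp(-x)) / 2) = x + log(1 + exp(-2x)) - log(2), so the only exponential ever evaluated is exp(-2x), with the sign handled inside `softplus`; nothing overflows even for large |x|, whereas `cosh(x)` itself overflows Float64 for x larger than roughly 710. A minimal standalone sketch of the comparison (the `softplus_stable` and `logcosh_stable` names are illustrative helpers, not NNlib code):

# Standalone sketch (illustrative names, not part of NNlib): compare the naive
# and stable evaluations of log(cosh(x)).

# Stable softplus, mirroring the branch used in src/activation.jl above.
softplus_stable(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))

# log(cosh(x)) = x + log(1 + exp(-2x)) - log(2), valid for any real x.
logcosh_stable(x::T) where T = x + softplus_stable(-2x) - log(convert(T, 2))

log(cosh(1.0))            # ≈ 0.4337808 (naive form, fine for small x)
logcosh_stable(1.0)       # same value up to floating-point rounding
log(cosh(1_000.0))        # Inf -- cosh(1000.0) overflows Float64
logcosh_stable(1_000.0)   # 1000 - log(2) ≈ 999.30685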

test/activation.jl

Lines changed: 6 additions & 1 deletion

@@ -1,6 +1,6 @@
 using NNlib, Test
 
-ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign];
+ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logcosh];
 
 function test_value_float_precision_preserving(a)
   @testset "$(a): " begin
@@ -36,6 +36,7 @@ end
 @test softplus(-1e8) ≈ 0.0
 @test softsign(0.0) == 0.0
 @test selu(0.0) == 0.0
+@test logcosh(0.0) == log(cosh(0.0))
 
 @test σ(1.0) == 1.0 / (1.0 + exp(-1.0))
 @test relu(1.0) == 1.0
@@ -46,6 +47,7 @@ end
 @test softplus(1.0) ≈ log(exp(1.0) + 1.0)
 @test softsign(1.0) == 0.5
 @test selu(1.0) == 1.0507009873554804934193349852946
+@test logcosh(1.0) ≈ log(cosh(1.0))
 
 @test σ(-1.0) == 1.0 / (1.0 + exp(1.0))
 @test relu(-1.0) == 0.0
@@ -56,6 +58,7 @@ end
 @test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)
 @test softsign(-1.0) == -0.5
 @test selu(-1.0) == 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0)
+@test logcosh(-1.0) ≈ log(cosh(-1.0))
 
 @testset "Float inference" begin
   test_value_float_precision_preserving.(ACTIVATION_FUNCTIONS)
@@ -132,4 +135,6 @@ end
   @eval @test logsigmoid.($T[-100_000, 100_000.]) ≈ $T[-100_000, 0.]
   end
 end
+
+@test logcosh(1_000.0) + log(2) == 1_000.0
 end
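
For reference (not part of the diff), a quick usage sketch assuming the NNlib version from this commit is loaded: scalars are passed directly, arrays go through broadcasting, which is exactly what the error message wired up in the `for f in (...)` loop in src/activation.jl points users to. The numeric values in the comments are rounded.

using NNlib

logcosh(0.0)               # 0.0
logcosh(2.0)               # ≈ 1.325, i.e. log(cosh(2.0))
logcosh.([-2.0, 0.0, 2.0]) # broadcast over an array: ≈ [1.325, 0.0, 1.325]
# logcosh([-2.0, 0.0, 2.0]) would throw:
#   "Use broadcasting (`logcosh.(x)`) to apply activation functions to arrays."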
