
Commit 3c7e935

add logcosh activation

Pavel Shashkin committed · 1 parent a6b0fee · commit 3c7e935

2 files changed: +14, -1 lines

src/activation.jl

Lines changed: 8 additions & 0 deletions
@@ -117,3 +117,11 @@ softsign(x) = x / (one(x) + abs(x))
 See [Deep Sparse Rectifier Neural Networks](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf).
 """
 softplus(x) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))
+
+
+"""
+    logcosh(x)
+
+Return `log(cosh(x))`, computed in a numerically stable way.
+"""
+logcosh(x::T) where T = x + softplus(-2x) - log(convert(T, 2))
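
Editor's note on the formula (not part of the commit): since cosh(x) = exp(x) * (1 + exp(-2x)) / 2, taking logs gives log(cosh(x)) = x + log1p(exp(-2x)) - log(2) = x + softplus(-2x) - log(2), which never exponentiates a large positive argument. A minimal sketch of the difference, reusing the two definitions from the diff above:

    # Sketch only; definitions copied from the commit above.
    softplus(x) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))
    logcosh(x::T) where T = x + softplus(-2x) - log(convert(T, 2))

    log(cosh(1_000.0))  # naive form: cosh(1_000.0) overflows to Inf, so this is Inf
    logcosh(1_000.0)    # stable form: ≈ 999.3069, i.e. 1_000.0 - log(2)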

test/activation.jl

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,6 @@
 using NNlib, Test
 
-ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign];
+ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logcosh];
 
 function test_value_float_precision_preserving(a)
   @testset "$(a): " begin
@@ -36,6 +36,7 @@ end
   @test softplus(-1e8) ≈ 0.0
   @test softsign(0.0) == 0.0
   @test selu(0.0) == 0.0
+  @test logcosh(0.0) == log(cosh(0.0))
 
   @test σ(1.0) == 1.0 / (1.0 + exp(-1.0))
   @test relu(1.0) == 1.0
@@ -46,6 +47,7 @@
   @test softplus(1.0) ≈ log(exp(1.0) + 1.0)
   @test softsign(1.0) == 0.5
   @test selu(1.0) == 1.0507009873554804934193349852946
+  @test logcosh(1.0) ≈ log(cosh(1.0))
 
   @test σ(-1.0) == 1.0 / (1.0 + exp(1.0))
   @test relu(-1.0) == 0.0
@@ -56,6 +58,7 @@
   @test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)
   @test softsign(-1.0) == -0.5
   @test selu(-1.0) == 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0)
+  @test logcosh(-1.0) ≈ log(cosh(-1.0))
 
   @testset "Float inference" begin
     test_value_float_precision_preserving.(ACTIVATION_FUNCTIONS)
@@ -125,4 +128,6 @@ end
       @eval @test logsigmoid.($T[-100_000, 100_000.]) ≈ $T[-100_000, 0.]
     end
   end
+
+  @test logcosh(1_000.0) + log(2) == 1_000.0
 end
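
Editor's note on the last test (not part of the commit): for large x, exp(-2x) underflows to 0.0, so softplus(-2x) is exactly zero and logcosh(x) evaluates to x - log(2) with no rounding beyond that single subtraction; that is why the test can use == rather than ≈. A quick check under the same definitions:

    # Sketch only, not part of the commit.
    softplus(-2 * 1_000.0) == 0.0         # exp(-2000.0) underflows; log1p(0.0) == 0.0
    logcosh(1_000.0) == 1_000.0 - log(2)  # both sides compute the identical expression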
