Skip to content

Commit b896e33

Browse files
Added Mish activation (#145)
* Added Mish activation Referring to issue - #144 * Added tests for Mish * Fixed Mish Test Co-authored-by: Manjunath Bhat <[email protected]>
1 parent 1c35815 commit b896e33

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/activation.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ,
2-
logsigmoid, logcosh
2+
logsigmoid, logcosh, mish
33

44
"""
55
σ(x) = 1 / (1 + exp(-x))
@@ -131,3 +131,12 @@ for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu
131131
@eval $(f)(x::AbstractArray, args...) =
132132
error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
133133
end
134+
135+
136+
"""
137+
mish(x) = x * tanh(softplus(x))
138+
139+
Self Regularized Non-Monotonic Neural Activation Function
140+
See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).
141+
"""
142+
mish(x::Real) = x * tanh(softplus(x))

test/activation.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ end
3737
@test softsign(0.0) == 0.0
3838
@test selu(0.0) == 0.0
3939
@test logcosh(0.0) == log(cosh(0.0))
40+
@test mish(0.0) == 0.0
4041

4142
@test σ(1.0) == 1.0 / (1.0 + exp(-1.0))
4243
@test relu(1.0) == 1.0
@@ -48,6 +49,7 @@ end
4849
@test softsign(1.0) == 0.5
4950
@test selu(1.0) == 1.0507009873554804934193349852946
5051
@test logcosh(1.0) ≈ log(cosh(1.0))
52+
@test mish(1.0) ≈ tanh(log(1.0 + exp(1.0)))
5153

5254
@test σ(-1.0) == 1.0 / (1.0 + exp(1.0))
5355
@test relu(-1.0) == 0.0
@@ -59,6 +61,7 @@ end
5961
@test softsign(-1.0) == -0.5
6062
@test selu(-1.0) == 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0)
6163
@test logcosh(-1.0) ≈ log(cosh(-1.0))
64+
@test mish(-1.0) ≈ -tanh(log(1.0 + exp(-1.0)))
6265

6366
@testset "Float inference" begin
6467
test_value_float_precision_preserving.(ACTIVATION_FUNCTIONS)
@@ -127,6 +130,13 @@ end
127130
@test elu(-4) (exp(-4) - 1)
128131
end
129132

133+
@testset "mish" begin
134+
@test mish(-5) -0.033576237730161704
135+
@test mish(9) == 9*tanh(log(1 + exp(9)))
136+
xs = Float32[1 2 3; 1000 2000 3000]
137+
@test typeof(mish.(xs)) == typeof(xs)
138+
end
139+
130140
@test leakyrelu( 0.4,0.3) 0.4
131141
@test leakyrelu(-0.4,0.3) -0.12
132142

0 commit comments

Comments
 (0)