
Commit 6131a17

Allow only real input for activation functions
1 parent d07ac0b commit 6131a17
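
In practice, the activation functions now accept only scalar (`Real`) input and raise an error for array input, steering users to broadcasting. A minimal sketch of the intended usage after this change (behaviour as exercised by the new tests below; `relu` stands in for any of the exported activations):

    using NNlib

    relu(0.5)        # 0.5 — scalar input works as before
    relu.(rand(5))   # elementwise application via broadcasting
    relu(rand(5))    # throws ErrorException, pointing to `relu.(x)`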

3 files changed (+36 -11 lines)

src/NNlib.jl

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@ module NNlib
 
 using Requires, Libdl
 
+using MacroTools: @capture
+
 export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid,
   softmax, logsoftmax, maxpool, meanpool
 
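
For context, MacroTools' `@capture` pattern-matches an expression and, on success, binds the underscore-suffixed placeholders in the enclosing scope; this is what the new `@onlyreal` macro below relies on. A rough, self-contained illustration (the example expression is hypothetical):

    using MacroTools: @capture

    ex = :(relu(x) = max(zero(x), x))

    # Returns true on a match and binds `f`, `a`, and `body`.
    @capture(ex, f_(x, a__) = body_)  # true
    f     # :relu
    a     # [] — no arguments after `x`
    body  # the right-hand side of the definition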

src/activation.jl

Lines changed: 23 additions & 11 deletions
@@ -1,15 +1,27 @@
+macro onlyreal(ex)
+  @capture(ex, (f_(x, a__) = body_) | (function f_(x, a__) body_ end)) ||
+    error("expected a function with initial argument `x`")
+
+  errmsg = "Use explicit invocations such as `$(f).(x)` to apply activation functions to tensors!"
+
+  quote
+    Base.@__doc__ $(f)(x::Real, $(a...)) = $body
+    $(f)(x::AbstractArray, $(a...)) = error($errmsg)
+  end |> esc
+end
+
 """
     σ(x) = 1 / (1 + exp(-x))
 
 Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
 function.
 """
-σ(x) = one(x) / (one(x) + exp(-x))
+@onlyreal σ(x) = one(x) / (one(x) + exp(-x))
 
 const sigmoid = σ
 
 # ForwardDiff numerical stability hack
-σ_stable(x) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))
+@onlyreal σ_stable(x) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))
 
 σ(x::Float32) = σ_stable(x)
 
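Roughly speaking, `@onlyreal σ(x) = ...` expands into a `Real` method plus an `AbstractArray` method that throws; a sketch of the generated definitions (illustrative, not the literal macro output):

    σ(x::Real) = one(x) / (one(x) + exp(-x))
    σ(x::AbstractArray) =
      error("Use explicit invocations such as `σ.(x)` to apply activation functions to tensors!")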

@@ -30,7 +42,7 @@ Return `log(σ(x))` which is computed in a numerically stable way.
  -10.0
  -0.0
 """
-function logσ(x)
+@onlyreal function logσ(x)
   max_v = max(zero(x), -x)
   z = exp(-max_v) + exp(-x-max_v)
   -(max_v + log(z))
@@ -44,7 +56,7 @@ const logsigmoid = logσ
 [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
 activation function.
 """
-relu(x) = max(zero(x), x)
+@onlyreal relu(x) = max(zero(x), x)
 
 
 """
@@ -54,7 +66,7 @@ Leaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_ne
 activation function.
 You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`.
 """
-leakyrelu(x, a = oftype(x/1, 0.01)) = max(a*x, x/1)
+@onlyreal leakyrelu(x, a = oftype(x/1, 0.01)) = max(a*x, x/1)
 
 """
     elu(x, α = 1) =
@@ -64,15 +76,15 @@ Exponential Linear Unit activation function.
 See [Fast and Accurate Deep Network Learning by Exponential Linear Units](https://arxiv.org/abs/1511.07289).
 You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
 """
-elu(x, α = one(x)) = ifelse(x ≥ 0, x/1, α * (exp(x) - one(x)))
+@onlyreal elu(x, α = one(x)) = ifelse(x ≥ 0, x/1, α * (exp(x) - one(x)))
 
 """
     gelu(x) = 0.5x*(1 + tanh(√(2/π)*(x + 0.044715x^3)))
 
 [Gaussian Error Linear Unit](https://arxiv.org/pdf/1606.08415.pdf)
 activation function.
 """
-function gelu(x)
+@onlyreal function gelu(x)
   λ = oftype(x/1, √(2/π))
   α = oftype(x/1, 0.044715)
   h = oftype(x/1, 0.5)
@@ -86,7 +98,7 @@ end
 Self-gated activation function.
 See [Swish: a Self-Gated Activation Function](https://arxiv.org/pdf/1710.05941.pdf).
 """
-swish(x) = x * σ(x)
+@onlyreal swish(x) = x * σ(x)
 
 """
     selu(x) = λ * (x ≥ 0 ? x : α * (exp(x) - 1))
@@ -97,7 +109,7 @@ swish(x) = x * σ(x)
 Scaled exponential linear units.
 See [Self-Normalizing Neural Networks](https://arxiv.org/pdf/1706.02515.pdf).
 """
-function selu(x)
+@onlyreal function selu(x)
   λ = oftype(x/1, 1.0507009873554804934193349852946)
   α = oftype(x/1, 1.6732632423543772848170429916717)
   λ * ifelse(x > 0, x/1, α * (exp(x) - 1))
@@ -108,12 +120,12 @@ end
 
 See [Quadratic Polynomials Learn Better Image Features](http://www.iro.umontreal.ca/~lisa/publications2/index.php/attachments/single/205).
 """
-softsign(x) = x / (one(x) + abs(x))
+@onlyreal softsign(x) = x / (one(x) + abs(x))
 
 
 """
     softplus(x) = log(exp(x) + 1)
 
 See [Deep Sparse Rectifier Neural Networks](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf).
 """
-softplus(x) = log1p(exp(x))
+@onlyreal softplus(x) = log1p(exp(x))
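
For the activations that take an extra coefficient (`leakyrelu`, `elu`), the coefficient is simply passed as an additional broadcast argument; array calls with an explicit coefficient now error as well (see the new test below). A small usage sketch:

    xs = rand(5)

    leakyrelu.(xs, 0.01)   # elementwise leaky ReLU with slope 0.01
    elu.(xs, 1)            # elementwise ELU with α = 1
    leakyrelu(xs, 0.01)    # throws ErrorException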

test/activation.jl

Lines changed: 11 additions & 0 deletions
@@ -82,6 +82,17 @@ end
 end
 end
 
+@testset "Array input" begin
+  x = rand(5)
+
+  for a in ACTIVATION_FUNCTIONS
+    if a == leakyrelu || a == elu
+      @test_throws ErrorException a(x, 1.0)
+    else
+      @test_throws ErrorException a(x)
+    end
+  end
+end
 
 xs = rand(5,5)
 