
Commit bc19012

Author: Avik Pal (committed)
Fix merge conflicts with master
2 parents 8ced0c0 + 6b987ee, commit bc19012

File tree

9 files changed: +97, -44 lines


.codecov.yml

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+comment: false

.travis.yml

Lines changed: 9 additions & 2 deletions

@@ -5,6 +5,11 @@ os:
   - osx
 julia:
   - 1.0
+  - 1.1
+  - nightly
+matrix:
+  allow_failures:
+    - julia: nightly
 notifications:
   email: false
 git:
@@ -13,6 +18,8 @@ env:
   # Disable test fuzzing for the moment, as we're a little too slow for Travis
   - NNLIB_TEST_FUZZING=false

-# Submit to Codecov
+# Submit to Codecov
 after_success:
-  - julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Codecov.submit(process_folder())'
+  - if [[ $TRAVIS_JULIA_VERSION = 1.1 ]] && [[ $TRAVIS_OS_NAME = linux ]]; then
+      julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Codecov.submit(process_folder())';
+    fi

Project.toml

Lines changed: 6 additions & 1 deletion

@@ -8,5 +8,10 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
+
+[extras]
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[targets]
+test = ["Test"]

REQUIRE

Lines changed: 1 addition & 0 deletions

@@ -2,3 +2,4 @@ julia 1.0
 Requires
 MacroTools
 BinaryProvider
+TimerOutputs

src/NNlib.jl

Lines changed: 2 additions & 0 deletions

@@ -1,6 +1,8 @@
 module NNlib
 using Requires, TimerOutputs

+const to = TimerOutput()
+
 # Include APIs
 include("dim_helpers.jl")
 include("activation.jl")

src/activation.jl

Lines changed: 27 additions & 13 deletions

@@ -1,17 +1,17 @@
 export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ,
-       logsigmoid
+       logsigmoid, logcosh

 """
     σ(x) = 1 / (1 + exp(-x))

 Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
 function.
 """
-σ(x) = one(x) / (one(x) + exp(-x))
+σ(x::Real) = one(x) / (one(x) + exp(-x))
 const sigmoid = σ

 # ForwardDiff numerical stability hack
-σ_stable(x) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))
+σ_stable(x::Real) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))
 σ(x::Float32) = σ_stable(x)
 @init @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
     σ(x::ForwardDiff.Dual{T,Float32}) where T = σ_stable(x)
@@ -27,11 +27,11 @@ Return `log(σ(x))` which is computed in a numerically stable way.
 -0.6931471805599453
 julia> logσ.([-100, -10, 100])
 3-element Array{Float64,1}:
--100.0
--10.000045398899218
+ -100.0
+ -10.000045398899218
  -3.720075976020836e-44
 """
-logσ(x) = -softplus(-x)
+logσ(x::Real) = -softplus(-x)
 const logsigmoid = logσ


@@ -41,7 +41,7 @@ const logsigmoid = logσ
 [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
 activation function.
 """
-relu(x) = max(zero(x), x)
+relu(x::Real) = max(zero(x), x)


 """
@@ -51,7 +51,7 @@ Leaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
 activation function.
 You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`.
 """
-leakyrelu(x, a = oftype(x/1, 0.01)) = max(a*x, x/1)
+leakyrelu(x::Real, a = oftype(x/1, 0.01)) = max(a*x, x/1)


 """
@@ -71,7 +71,7 @@ elu(x, α = one(x)) = ifelse(x ≥ 0, x/1, α * (exp(x) - one(x)))
 [Gaussian Error Linear Unit](https://arxiv.org/pdf/1606.08415.pdf)
 activation function.
 """
-function gelu(x)
+function gelu(x::Real)
     λ = oftype(x/1, √(2/π))
     α = oftype(x/1, 0.044715)
     h = oftype(x/1, 0.5)
@@ -85,7 +85,7 @@ end
 Self-gated activation function.
 See [Swish: a Self-Gated Activation Function](https://arxiv.org/pdf/1710.05941.pdf).
 """
-swish(x) = x * σ(x)
+swish(x::Real) = x * σ(x)

 """
     selu(x) = λ * (x ≥ 0 ? x : α * (exp(x) - 1))
@@ -96,7 +96,7 @@ swish(x) = x * σ(x)
 Scaled exponential linear units.
 See [Self-Normalizing Neural Networks](https://arxiv.org/pdf/1706.02515.pdf).
 """
-function selu(x)
+function selu(x::Real)
     λ = oftype(x/1, 1.0507009873554804934193349852946)
     α = oftype(x/1, 1.6732632423543772848170429916717)
     λ * ifelse(x > 0, x/1, α * (exp(x) - 1))
@@ -108,12 +108,26 @@ end

 See [Quadratic Polynomials Learn Better Image Features](http://www.iro.umontreal.ca/~lisa/publications2/index.php/attachments/single/205).
 """
-softsign(x) = x / (one(x) + abs(x))
+softsign(x::Real) = x / (one(x) + abs(x))


 """
     softplus(x) = log(exp(x) + 1)

 See [Deep Sparse Rectifier Neural Networks](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf).
 """
-softplus(x) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))
+softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))
+
+
+"""
+    logcosh(x)
+
+Return `log(cosh(x))` which is computed in a numerically stable way.
+"""
+logcosh(x::T) where T = x + softplus(-2x) - log(convert(T, 2))
+
+# Provide an informative error message if activation functions are called with an array
+for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu, :softsign, :softplus, :logcosh)
+    @eval $(f)(x::AbstractArray, args...) =
+        error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
+end
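Taken together, the changes to this file boil down to: scalar-only methods (`::Real`), a new numerically stable `logcosh`, and a clear error for array arguments. A quick REPL-style check of that behavior (the values follow directly from the definitions above):

    using NNlib

    logcosh(0.0)             # 0.0
    logcosh(1_000.0)         # finite, equals 1_000.0 - log(2); cosh(1_000.0) itself overflows

    x = randn(Float32, 5)
    relu.(x)                 # broadcasting over arrays still works
    relu(x)                  # now throws: Use broadcasting (`relu.(x)`) to apply activation functions to arrays.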

src/impl/conv_im2col.jl

Lines changed: 18 additions & 12 deletions

@@ -50,10 +50,12 @@ which should eliminate any need for large allocations within this method.
         # We invoke `@timeit_debug` on the outside of `im2col!()` because inference
         # doesn't like us putting it on the inside.
         @timeit_debug to "im2col!" im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
-        col_ptr = pointer(col)
-        w_ptr = pointer(w)
-        y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
-        @timeit_debug to "gemm!" gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
+        GC.@preserve col, w, y, begin
+            col_ptr = pointer(col)
+            w_ptr = pointer(w)
+            y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
+            @timeit_debug to "gemm!" gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
+        end
     end
     return y
 end
@@ -96,10 +98,12 @@ See the documentation for `conv_im2col!()` for explanation of optional parameters.
         # We invoke `@timeit_debug` on the outside of `im2col!()` because inference
         # doesn't like us putting it on the inside.
         @timeit_debug to "im2col!" im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
-        col_ptr = pointer(col)
-        dy_ptr = pointer(dy,(batch_idx - 1)*K*N + 1)
-        dw_ptr = pointer(dw)
-        @timeit_debug to "gemm!" gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
+        GC.@preserve col, dw, dy, begin
+            col_ptr = pointer(col)
+            dy_ptr = pointer(dy,(batch_idx - 1)*K*N + 1)
+            dw_ptr = pointer(dw)
+            @timeit_debug to "gemm!" gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
+        end

         # Because we accumulate over batches in this loop, we must set `beta` equal
         # to `1.0` from this point on.
@@ -141,10 +145,12 @@ See the documentation for `conv_im2col!()` for explanation of other parameters.
     K = channels_out(cdims)

     @inbounds for batch_idx in 1:size(dx, 5)
-        dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
-        w_ptr = pointer(w)
-        col_ptr = pointer(col)
-        @timeit_debug to "gemm!" gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
+        GC.@preserve col, w, dy, begin
+            dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
+            w_ptr = pointer(w)
+            col_ptr = pointer(col)
+            @timeit_debug to "gemm!" gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
+        end
         @timeit_debug to "col2im!" col2im!(view(dx, :, :, :, :, batch_idx), col, cdims)
     end
     return dx
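The `GC.@preserve` blocks added above guard a general hazard: a `Ptr` obtained from `pointer(A)` does not keep `A` rooted, so without the block the GC is free to reclaim the buffer while `gemm!` is still writing through the raw pointer. A minimal standalone sketch of the pattern (illustrative code, not part of NNlib):

    function add_one!(y::Vector{Float64})
        GC.@preserve y begin      # keep `y` alive while we work through the raw pointer
            y_ptr = pointer(y)
            for i in 1:length(y)
                unsafe_store!(y_ptr, unsafe_load(y_ptr, i) + 1.0, i)
            end
        end
        return y
    end

    add_one!([1.0, 2.0, 3.0])     # returns [2.0, 3.0, 4.0]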

src/impl/depthwiseconv_im2col.jl

Lines changed: 20 additions & 15 deletions

@@ -35,10 +35,12 @@ depthwiseconv_im2col!
         # We do a separate convolution for each channel in x, as we must
         for c_in in 1:channels_in(cdims)
             # Walk each pointer forward as we process each input channel
-            col_ptr = pointer(col, (c_in-1)*M*K+1)
-            w_ptr = pointer(w, (c_in-1)*K*N+1)
-            y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
-            gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
+            GC.@preserve col, w, y, begin
+                col_ptr = pointer(col, (c_in-1)*M*K+1)
+                w_ptr = pointer(w, (c_in-1)*K*N+1)
+                y_ptr = pointer(y, ((batch_idx - 1)*channels_in(cdims) + c_in - 1)*M*N + 1)
+                gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
+            end
         end
     end
     return y
@@ -71,11 +73,12 @@ See the documentation for `conv_im2col!()` for explanation of optional parameters.
         # We do a separate convolution for each channel in x, as we must
         for c_in in 1:channels_in(cdims)
             # Walk each pointer forward as we process each input channel
-            col_ptr = pointer(col, (c_in - 1)*M*K + 1)
-            dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1)
-            dw_ptr = pointer(dw, (c_in - 1)*M*N + 1)
-
-            gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
+            GC.@preserve col, dw, dy, begin
+                col_ptr = pointer(col, (c_in - 1)*M*K + 1)
+                dy_ptr = pointer(dy, (batch_idx - 1)*N*K*channels_in(cdims) + (c_in - 1)*K*N + 1)
+                dw_ptr = pointer(dw, (c_in - 1)*M*N + 1)
+                gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
+            end
         end

         # Because we accumulate over batches in this loop, we must set `beta` equal
@@ -107,13 +110,15 @@ See the documentation for `conv_im2col!()` for explanation of optional parameters.
     @inbounds for batch_idx in 1:size(dx)[end]
         # We do a separate convolution for each channel in x, as we must
         for cidx in 1:channels_in(cdims)
-            # Walk each pointer forward as we process each input channel
-            dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
-            w_ptr = pointer(w, (cidx - 1)*K*N + 1)
-            col_ptr = pointer(col, (cidx - 1)*M*N + 1)
-            gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
+            GC.@preserve col, w, dy, begin
+                # Walk each pointer forward as we process each input channel
+                dy_ptr = pointer(dy, (batch_idx - 1)*M*K*channels_in(cdims)+(cidx - 1)*K*M + 1)
+                w_ptr = pointer(w, (cidx - 1)*K*N + 1)
+                col_ptr = pointer(col, (cidx - 1)*M*N + 1)
+                gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
+            end
         end
         @timeit_debug to "col2im!" col2im!(view(dx, :, :, :, :, batch_idx), col, cdims)
     end
     return dx
-end
+end

test/activation.jl

Lines changed: 13 additions & 1 deletion

@@ -1,6 +1,6 @@
 using NNlib, Test

-ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign];
+ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logcosh];

 function test_value_float_precision_preserving(a)
     @testset "$(a): " begin
@@ -36,6 +36,7 @@ end
     @test softplus(-1e8) ≈ 0.0
     @test softsign(0.0) == 0.0
     @test selu(0.0) == 0.0
+    @test logcosh(0.0) == log(cosh(0.0))

     @test σ(1.0) == 1.0 / (1.0 + exp(-1.0))
     @test relu(1.0) == 1.0
@@ -46,6 +47,7 @@ end
     @test softplus(1.0) ≈ log(exp(1.0) + 1.0)
     @test softsign(1.0) == 0.5
     @test selu(1.0) == 1.0507009873554804934193349852946
+    @test logcosh(1.0) ≈ log(cosh(1.0))

     @test σ(-1.0) == 1.0 / (1.0 + exp(1.0))
     @test relu(-1.0) == 0.0
@@ -56,11 +58,19 @@ end
     @test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)
     @test softsign(-1.0) == -0.5
     @test selu(-1.0) == 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0)
+    @test logcosh(-1.0) ≈ log(cosh(-1.0))

     @testset "Float inference" begin
         test_value_float_precision_preserving.(ACTIVATION_FUNCTIONS)
     end

+    @testset "Array input" begin
+        x = rand(5)
+        for a in ACTIVATION_FUNCTIONS
+            @test_throws ErrorException a(x)
+        end
+    end
+
     @testset "Test Integer64 and Integer32 inputs will force Float64 outputs" begin
         test_value_int_input_forces_float64.(filter(x -> x != relu, ACTIVATION_FUNCTIONS))

@@ -125,4 +135,6 @@ end
             @eval @test logsigmoid.($T[-100_000, 100_000.]) ≈ $T[-100_000, 0.]
         end
     end
+
+    @test logcosh(1_000.0) + log(2) == 1_000.0
 end
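The final assertion leans on the softplus-based definition of `logcosh`; the naive `log(cosh(x))` cannot even be evaluated at that magnitude because `cosh` overflows `Float64`. A REPL check of exactly what the test asserts:

    julia> cosh(1_000.0)                          # overflows
    Inf

    julia> logcosh(1_000.0) + log(2) == 1_000.0   # the identity the test relies on
    true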
