Skip to content

Commit b696529

Browse files
authored
Merge branch 'master' into master
2 parents 92d322b + a470e99 commit b696529

16 files changed

+288
-103
lines changed

.travis.yml

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,6 @@ notifications:
1010
git:
1111
depth: 99999999
1212

13-
## uncomment the following lines to allow failures on nightly julia
14-
## (tests will run but not make your overall status red)
15-
#matrix:
16-
# allow_failures:
17-
# - julia: nightly
18-
19-
## uncomment and modify the following lines to manually install system packages
20-
#addons:
21-
# apt: # apt-get for linux
22-
# packages:
23-
# - gfortran
24-
#before_script: # homebrew for mac
25-
# - if [ $TRAVIS_OS_NAME = osx ]; then brew install gcc; fi
26-
27-
## uncomment the following lines to override the default test script
28-
#script:
29-
# - julia -e 'Pkg.clone(pwd()); Pkg.build("NNlib"); Pkg.test("NNlib"; coverage=true)'
13+
# Submit to Codecov
14+
after_success:
15+
- julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Codecov.submit(process_folder())'

Manifest.toml

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
[[Base64]]
2+
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
3+
4+
[[Compat]]
5+
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
6+
git-tree-sha1 = "ff2595695fc4f14427358ce2593f867085c45dcb"
7+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
8+
version = "1.2.0"
9+
10+
[[Dates]]
11+
deps = ["Printf"]
12+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
13+
14+
[[DelimitedFiles]]
15+
deps = ["Mmap"]
16+
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
17+
18+
[[Distributed]]
19+
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
20+
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
21+
22+
[[InteractiveUtils]]
23+
deps = ["LinearAlgebra", "Markdown"]
24+
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
25+
26+
[[LibGit2]]
27+
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
28+
29+
[[Libdl]]
30+
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
31+
32+
[[LinearAlgebra]]
33+
deps = ["Libdl"]
34+
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
35+
36+
[[Logging]]
37+
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
38+
39+
[[MacroTools]]
40+
deps = ["Compat"]
41+
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
42+
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
43+
version = "0.4.4"
44+
45+
[[Markdown]]
46+
deps = ["Base64"]
47+
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
48+
49+
[[Mmap]]
50+
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
51+
52+
[[Pkg]]
53+
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
54+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
55+
56+
[[Printf]]
57+
deps = ["Unicode"]
58+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
59+
60+
[[REPL]]
61+
deps = ["InteractiveUtils", "Markdown", "Sockets"]
62+
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
63+
64+
[[Random]]
65+
deps = ["Serialization"]
66+
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
67+
68+
[[Requires]]
69+
deps = ["Test"]
70+
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
71+
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
72+
version = "0.5.2"
73+
74+
[[SHA]]
75+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
76+
77+
[[Serialization]]
78+
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
79+
80+
[[SharedArrays]]
81+
deps = ["Distributed", "Mmap", "Random", "Serialization"]
82+
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
83+
84+
[[Sockets]]
85+
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
86+
87+
[[SparseArrays]]
88+
deps = ["LinearAlgebra", "Random"]
89+
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
90+
91+
[[Statistics]]
92+
deps = ["LinearAlgebra", "SparseArrays"]
93+
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
94+
95+
[[Test]]
96+
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
97+
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
98+
99+
[[UUIDs]]
100+
deps = ["Random"]
101+
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
102+
103+
[[Unicode]]
104+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

Project.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
name = "NNlib"
2+
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
3+
4+
[deps]
5+
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
6+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
8+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
9+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# NNlib
22

3-
[![Build Status](https://travis-ci.org/FluxML/NNlib.jl.svg?branch=master)](https://travis-ci.org/FluxML/NNlib.jl) [![Build status](https://ci.appveyor.com/api/projects/status/wo2wkv1l9cj548uh?svg=true)](https://ci.appveyor.com/project/one-more-minute/nnlib-jl)
3+
[![Build Status](https://travis-ci.org/FluxML/NNlib.jl.svg?branch=master)](https://travis-ci.org/FluxML/NNlib.jl) [![Build status](https://ci.appveyor.com/api/projects/status/wo2wkv1l9cj548uh?svg=true)](https://ci.appveyor.com/project/one-more-minute/nnlib-jl) [![Coverage](https://codecov.io/gh/FluxML/NNlib/branch/master/graph/badge.svg)](https://codecov.io/gh/FluxML/NNlib)
4+
45

56
This package will provide a library of functions useful for ML, such as softmax, sigmoid, convolutions and pooling. It doesn't provide any other "high-level" functionality like layers or AD.
67

appveyor.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,6 @@ test_script:
4040
# on_success:
4141
# - echo "%JL_CODECOV_SCRIPT%"
4242
# - C:\julia\bin\julia -e "%JL_CODECOV_SCRIPT%"
43+
44+
after_test:
45+
- C:\julia\bin\julia -e "using Pkg; Pkg.add(\"Coverage\"); using Coverage; Codecov.submit(process_folder())"

src/NNlib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module NNlib
22

33
using Requires, Libdl
44

5-
export σ, sigmoid, relu, leakyrelu, elu, swish, selu, softplus, softsign, logσ, logsigmoid,
5+
export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid,
66
softmax, logsoftmax, maxpool, meanpool
77

88
include("numeric.jl")

src/activation.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,20 @@ You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
6666
"""
6767
elu(x, α = one(x)) = ifelse(x 0, x/1, α * (exp(x) - one(x)))
6868

69+
"""
70+
gelu(x) = 0.5x*(1 + tanh(√(2/π)*(x + 0.044715x^3)))
71+
72+
[Gaussian Error Linear Unit](https://arxiv.org/pdf/1606.08415.pdf)
73+
activation function.
74+
"""
75+
function gelu(x)
76+
λ = oftype(x/1, (2/π))
77+
α = oftype(x/1, 0.044715)
78+
h = oftype(x/1, 0.5)
79+
h * x * (one(x) + tanh* (x + α * x^3)))
80+
end
81+
82+
6983
"""
7084
swish(x) = x * σ(x)
7185

src/conv.jl

Lines changed: 71 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,74 @@ function cdims(x::NTuple{N}, w::NTuple{N}, pad, stride) where N
1717
end
1818
end
1919

20+
21+
# Conv Transpose dims
22+
23+
function ctdims(x::NTuple{N}, w::NTuple{N}, pad, stride, dilation) where N
24+
ntuple(Val(N)) do i
25+
if i < N-1
26+
(x[i] - 1) * stride[i] + dilation[i] * (w[i] - 1) - 2*pad[i] + 1
27+
elseif i == N-1
28+
w[N-1]
29+
else # i == N
30+
x[N]
31+
end
32+
end
33+
end
34+
35+
36+
# Kernel dims
37+
38+
function wdims(x::NTuple{N}, y::NTuple{N}, pad, stride, dilation) where N
39+
ntuple(Val(N)) do i
40+
if i < N-1
41+
1 + div((1 - y[i]) * stride[i] + x[i] + 2pad[i] - 1, dilation[i])
42+
elseif i == N-1
43+
x[i]
44+
else # i == N
45+
y[i-1]
46+
end
47+
end
48+
end
49+
2050
# Interface
2151

2252
head(x) = reverse(Base.tail(reverse(x)))
2353
padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x)))
2454
padtuple(x::Tuple,p::Tuple) = p
2555
padtuple(x::AbstractArray,p) = padtuple(size(x),p)
2656

27-
function conv(x::AbstractArray, w::AbstractArray; pad = 0, stride = 1, dilation = 1)
57+
function conv(x::AbstractArray, w::AbstractArray; size=nothing, pad = 0, stride = 1, dilation = 1)
2858
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
29-
conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
30-
x, w, pad = pad_, stride = stride_, dilation = dilation)
59+
if size === nothing
60+
size = cdims(Base.size(x), dilation_dims(w, dilation), pad_, stride_)
61+
end
62+
conv!(similar(x, size), x, w, pad = pad_, stride = stride_, dilation = dilation)
3163
end
3264

33-
function crosscor(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
65+
function crosscor(x::A, w::A; size=nothing, pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
3466
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
35-
crosscor!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
36-
x, w, pad = pad_, stride = stride_, dilation = dilation)
67+
if size === nothing
68+
size = cdims(Base.size(x), dilation_dims(w, dilation), pad_, stride_)
69+
end
70+
crosscor!(similar(x, size), x, w, pad = pad_, stride = stride_, dilation = dilation)
3771
end
3872

39-
∇conv_data(dy::AbstractArray, x::AbstractArray, w::AbstractArray; pad = 0, stride = 1, dilation = 1, flipkernel = 0) =
40-
∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
73+
function ∇conv_data(dy::AbstractArray, w::AbstractArray; size=nothing, pad = 0, stride = 1, dilation = 1, flipkernel = 0)
74+
pad_, stride_, dilation_ = padtuple(dy, pad), padtuple(dy, stride), padtuple(dy, dilation)
75+
if size === nothing
76+
size = ctdims(Base.size(dy), Base.size(w), pad_, stride_, dilation_)
77+
end
78+
∇conv_data!(similar(dy, size), dy, w, pad = pad_, stride = stride_, dilation = dilation_, flipkernel=flipkernel)
79+
end
4180

42-
∇conv_filter(dy::AbstractArray, x::AbstractArray, w::AbstractArray; pad = 0, stride = 1, dilation = 1, flipkernel=0) =
43-
∇conv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
81+
function ∇conv_filter(dy::AbstractArray, x::AbstractArray; size = nothing, pad = 0, stride = 1, dilation = 1, flipkernel=0)
82+
pad_, stride_, dilation_ = padtuple(dy, pad), padtuple(dy, stride), padtuple(dy, dilation)
83+
if size === nothing
84+
size = wdims(Base.size(x), Base.size(dy), pad_, stride_, dilation_)
85+
end
86+
∇conv_filter!(zero(similar(dy, size)), dy, x; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
87+
end
4488

4589
# N-D dispatch
4690

@@ -56,18 +100,16 @@ function crosscor!(y::AbstractArray, x::AbstractArray, w::AbstractArray;
56100
conv!(y, x, w, pad=pad, stride=stride, dilation=dilation, flipkernel=1)
57101
end
58102

59-
function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
60-
x::AbstractArray{T,3}, w::AbstractArray{T,3};
103+
function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3}, x::AbstractArray{T,3};
61104
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T
62-
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dw, dy, x, w))
105+
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dw, dy, x))
63106
∇conv_filter!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), flipkernel=flipkernel)
64107
return dw
65108
end
66109

67-
function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
68-
x::AbstractArray{T,3}, w::AbstractArray{T,3};
69-
pad = 0, stride = 1, dilation = 1, flipkernel = 0) where T
70-
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dx, dy, x, w))
110+
function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3}, w::AbstractArray{T,3};
111+
pad = 0, stride = 1, dilation = 1, flipkernel = 0) where T
112+
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dx, dy, w))
71113
∇conv_data!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation..., 1), flipkernel = flipkernel)
72114
return dx
73115
end
@@ -76,25 +118,25 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
76118
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
77119
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
78120

79-
∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
121+
∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4};
80122
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
81-
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
123+
conv2d_grad_w!(dw, x, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
82124

83-
∇conv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
125+
∇conv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, w::AbstractArray{T,4};
84126
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
85-
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
127+
conv2d_grad_x!(dx, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
86128

87129
conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
88130
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
89131
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
90132

91-
∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
133+
∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5};
92134
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
93-
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
135+
conv3d_grad_w!(dw, x, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
94136

95-
∇conv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
137+
∇conv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, w::AbstractArray{T,5};
96138
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
97-
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
139+
conv3d_grad_x!(dx, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
98140

99141
# Depthwise Conv
100142

@@ -217,10 +259,8 @@ meanpool_cpu!(y::AbstractArray{<:Real,5}, x::AbstractArray{<:Real,5}, k::Dims{3}
217259
meanpool3d_grad!(dx, dy, y, x,
218260
window = k, padding = pad, stride = stride)
219261

220-
# Deprecated 0.3
221-
222-
export conv2d, maxpool2d, avgpool2d
262+
# Deprecated
223263

224-
@deprecate conv2d(x, w; kw...) NNlib.conv(x, w; kw...)
225-
@deprecate maxpool2d(x::AbstractArray{<:Real,4}, k::Integer) maxpool(x, (k,k))
226-
@deprecate meanpool2d(x::AbstractArray{<:Real,4}, k::Integer) meanpool(x, (k,k))
264+
# 0.4.2
265+
@deprecate ∇conv_data(dy::A, x::A, w::A; kw...) where A<:AbstractArray ∇conv_data(dy, w; size=size(x), kw...)
266+
@deprecate ∇conv_filter(dy::A, x::A, w::A; kw...) where A<:AbstractArray ∇conv_filter(dy, x; size=size(w), kw...)

0 commit comments

Comments
 (0)