Skip to content

Commit fe52493

Browse files
committed
Merge branch 'master' into conv
2 parents e7fd08d + 41d91eb commit fe52493

File tree

5 files changed

+112
-65
lines changed

5 files changed

+112
-65
lines changed

appveyor.yml

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
environment:
22
matrix:
3-
- JULIA_URL: "https://julialang-s3.julialang.org/bin/winnt/x86/0.6/julia-0.6-latest-win32.exe"
4-
- JULIA_URL: "https://julialang-s3.julialang.org/bin/winnt/x64/0.6/julia-0.6-latest-win64.exe"
5-
# - JULIA_URL: "https://julialangnightlies-s3.julialang.org/bin/winnt/x86/julia-latest-win32.exe"
6-
# - JULIA_URL: "https://julialangnightlies-s3.julialang.org/bin/winnt/x64/julia-latest-win64.exe"
7-
8-
## uncomment the following lines to allow failures on nightly julia
9-
## (tests will run but not make your overall status red)
10-
#matrix:
11-
# allow_failures:
12-
# - JULIA_URL: "https://julialangnightlies-s3.julialang.org/bin/winnt/x86/julia-latest-win32.exe"
13-
# - JULIA_URL: "https://julialangnightlies-s3.julialang.org/bin/winnt/x64/julia-latest-win64.exe"
3+
- julia_version: 1
4+
- julia_version: nightly
5+
6+
platform:
7+
- x86 # 32-bit
8+
- x64 # 64-bit
9+
10+
# # Uncomment the following lines to allow failures on nightly julia
11+
# # (tests will run but not make your overall status red)
12+
# matrix:
13+
# allow_failures:
14+
# - julia_version: nightly
1415

1516
branches:
1617
only:
@@ -24,24 +25,18 @@ notifications:
2425
on_build_status_changed: false
2526

2627
install:
27-
- ps: "[System.Net.ServicePointManager]::SecurityProtocol = [System.Net.SecurityProtocolType]::Tls12"
28-
# If there's a newer build queued for the same PR, cancel this one
29-
- ps: if ($env:APPVEYOR_PULL_REQUEST_NUMBER -and $env:APPVEYOR_BUILD_NUMBER -ne ((Invoke-RestMethod `
30-
https://ci.appveyor.com/api/projects/$env:APPVEYOR_ACCOUNT_NAME/$env:APPVEYOR_PROJECT_SLUG/history?recordsNumber=50).builds | `
31-
Where-Object pullRequestId -eq $env:APPVEYOR_PULL_REQUEST_NUMBER)[0].buildNumber) { `
32-
throw "There are newer queued builds for this pull request, failing early." }
33-
# Download most recent Julia Windows binary
34-
- ps: (new-object net.webclient).DownloadFile(
35-
$env:JULIA_URL,
36-
"C:\projects\julia-binary.exe")
37-
# Run installer silently, output to C:\projects\julia
38-
- C:\projects\julia-binary.exe /S /D=C:\projects\julia
28+
- ps: iex ((new-object net.webclient).DownloadString("https://raw.githubusercontent.com/JuliaCI/Appveyor.jl/version-1/bin/install.ps1"))
3929

4030
build_script:
41-
# Need to convert from shallow to complete for Pkg.clone to work
42-
- IF EXIST .git\shallow (git fetch --unshallow)
43-
- C:\projects\julia\bin\julia -e "versioninfo();
44-
Pkg.clone(pwd(), \"NNlib\"); Pkg.build(\"NNlib\")"
31+
- echo "%JL_BUILD_SCRIPT%"
32+
- C:\julia\bin\julia -e "%JL_BUILD_SCRIPT%"
4533

4634
test_script:
47-
- C:\projects\julia\bin\julia -e "Pkg.test(\"NNlib\")"
35+
- echo "%JL_TEST_SCRIPT%"
36+
- C:\julia\bin\julia -e "%JL_TEST_SCRIPT%"
37+
38+
# # Uncomment to support code coverage upload. Should only be enabled for packages
39+
# # which would have coverage gaps without running on Windows
40+
# on_success:
41+
# - echo "%JL_CODECOV_SCRIPT%"
42+
# - C:\julia\bin\julia -e "%JL_CODECOV_SCRIPT%"

src/conv.jl

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -30,62 +30,73 @@ function conv(x::A, w::B; pad = 0, stride = 1, dilation = 1) where {A<:AbstractA
3030
x, w, pad = pad_, stride = stride_, dilation = dilation)
3131
end
3232

33-
∇conv_data(dy::A, x::B, w::C; pad = 0, stride = 1, dilation = 1) where {A<:AbstractArray, B<:AbstractArray, C<:AbstractArray} =
34-
∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
33+
function crosscor(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
34+
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
35+
crosscor!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
36+
x, w, pad = pad_, stride = stride_, dilation = dilation)
37+
end
3538

36-
∇conv_filter(dy::A, x::B, w::C; pad = 0, stride = 1, dilation = 1) where {A<:AbstractArray, B<:AbstractArray, C<:AbstractArray} =
37-
∇conv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
39+
∇conv_data(dy::A, x::B, w::C; pad = 0, stride = 1, dilation = 1, flipkernel = 0) where {A<:AbstractArray, B<:AbstractArray, C<:AbstractArray} =
40+
∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
41+
42+
∇conv_filter(dy::A, x::B, w::C; pad = 0, stride = 1, dilation = 1, flipkernel=0) where {A<:AbstractArray, B<:AbstractArray, C<:AbstractArray} =
43+
∇conv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
3844

3945
# N-D dispatch
4046

4147
function conv!(y::AbstractArray{T,3}, x::AbstractArray{T,3}, w::AbstractArray{T,3};
42-
pad = 0, stride = 1, dilation = 1) where T
48+
pad = 0, stride = 1, dilation = 1, flipkernel =0) where T
4349
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (y, x, w))
44-
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
50+
conv!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), flipkernel=flipkernel)
4551
return y
4652
end
4753

54+
function crosscor!(y::AbstractArray, x::AbstractArray, w::AbstractArray;
55+
pad = 0, stride = 1, dilation = 1)
56+
conv!(y, x, w, pad=pad, stride=stride, dilation=dilation, flipkernel=1)
57+
end
58+
4859
function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
4960
x::AbstractArray{T,3}, w::AbstractArray{T,3};
50-
pad = 0, stride = 1, dilation = 1) where T
61+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T
5162
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dw, dy, x, w))
52-
∇conv_filter!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
63+
∇conv_filter!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), flipkernel=flipkernel)
5364
return dw
5465
end
5566

5667
function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
5768
x::AbstractArray{T,3}, w::AbstractArray{T,3};
58-
pad = 0, stride = 1, dilation = 1) where T
69+
pad = 0, stride = 1, dilation = 1, flipkernel = 0) where T
5970
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dx, dy, x, w))
60-
∇conv_data!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation..., 1))
71+
∇conv_data!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation..., 1), flipkernel = flipkernel)
6172
return dx
6273
end
6374

6475
conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
65-
pad = 0, stride = 1, dilation = 1) where T =
66-
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
76+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
77+
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
6778

6879
∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
69-
pad = 0, stride = 1, dilation = 1) where T =
70-
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
80+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
81+
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
7182

7283
∇conv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
73-
pad = 0, stride = 1, dilation = 1) where T =
74-
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
84+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
85+
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
7586

7687
conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
77-
pad = 0, stride = 1, dilation = 1) where T =
78-
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
88+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
89+
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
7990

8091
∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
81-
pad = 0, stride = 1, dilation = 1) where T =
82-
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
92+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
93+
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
8394

8495
∇conv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
85-
pad = 0, stride = 1, dilation = 1) where T =
86-
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
96+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
97+
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
8798

88-
# Depthwise Conv
99+
# Depthwise Conv
89100

90101
function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
91102
((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4])
@@ -96,23 +107,32 @@ function depthwiseconv(x::A, w::B; pad = 0, stride = 1) where {A<:AbstractArray,
96107
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
97108
end
98109

110+
function depthwisecrosscor(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
111+
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
112+
depthwisecrosscor!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
113+
end
114+
99115
depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
116+
pad = 0, stride = 1, flipkernel=0) where T =
117+
depthwiseconv2d!(y, x, w, padding = pad, stride = stride, mode= flipkernel)
118+
119+
depthwisecrosscor!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
100120
pad = 0, stride = 1) where T =
101-
depthwiseconv2d!(y, x, w, padding = pad, stride = stride)
121+
depthwiseconv!(y, x, w, pad = pad, stride = stride, flipkernel=1)
102122

103-
∇depthwiseconv_data(dy::A, x::B, w::C; pad = 0, stride = 1) where {A<:AbstractArray, B<:AbstractArray, C<:AbstractArray} =
104-
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)
123+
∇depthwiseconv_data(dy::A, x::B, w::C; pad = 0, stride = 1, flipkernel=0) where {A<:AbstractArray, B<:AbstractArray, C<:AbstractArray} =
124+
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride, flipkernel=flipkernel)
105125

106-
∇depthwiseconv_filter(dy::A, x::B, w::C; pad = 0, stride = 1) where {A<:AbstractArray, B<:AbstractArray, C<:AbstractArray} =
107-
∇depthwiseconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride)
126+
∇depthwiseconv_filter(dy::A, x::B, w::C; pad = 0, stride = 1, flipkernel=0) where {A<:AbstractArray, B<:AbstractArray, C<:AbstractArray} =
127+
∇depthwiseconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, flipkernel=flipkernel)
108128

109129
∇depthwiseconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
110-
pad = 0, stride = 1) where T =
111-
depthwiseconv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride)
130+
pad = 0, stride = 1, flipkernel=0) where T =
131+
depthwiseconv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, mode=flipkernel)
112132

113133
∇depthwiseconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
114-
pad = 0, stride = 1) where T =
115-
depthwiseconv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride)
134+
pad = 0, stride = 1, flipkernel=0) where T =
135+
depthwiseconv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, mode=flipkernel)
116136

117137
# Pooling
118138

src/impl/conv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
function psize(p, x)
33
nd = ndims(x)-2
44
if isa(p,Number)
5-
fill(Int(p),nd)
5+
ntuple(_->Int(p), nd)
66
elseif length(p)==nd
7-
collect(Int,p)
7+
tuple(p...)
88
else
99
throw(DimensionMismatch("psize: $p $nd"))
1010
end

src/linalg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ for (gemm, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32))
1818
if transA=='N'; lda=M; else; lda=K; end
1919
if transB=='N'; ldb=K; else; ldb=N; end
2020
ldc = M;
21-
ccall((@blasfunc(dgemm_), libblas), Nothing,
21+
ccall((@blasfunc($(gemm)), libblas), Nothing,
2222
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
23-
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
24-
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
23+
Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BlasInt},
24+
Ptr{$elty}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
2525
Ref{BlasInt}),
2626
transA, transB, M, N, K,
2727
alpha, A, lda, B, ldb, beta, C, ldc)

test/conv.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,20 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwisec
1616
49 99 149;
1717
59 109 159.]
1818

19+
@test dropdims(conv(Float32.(x), Float32.(w)), dims=(3,4)) == Float32.([
20+
29 79 129;
21+
39 89 139;
22+
49 99 149;
23+
59 109 159.])
24+
1925
@test dropdims(conv(x, w; stride=2), dims = (3,4)) == [
2026
29 129;
2127
49 149.]
2228

29+
@test dropdims(conv(Float32.(x), Float32.(w); stride=2), dims = (3,4)) == Float32.([
30+
29 129;
31+
49 149.])
32+
2333
@test dropdims(conv(x, w; pad=1), dims = (3,4)) == [
2434
1.0 9.0 29.0 49.0 48.0;
2535
4.0 29.0 79.0 129.0 115.0;
@@ -29,6 +39,15 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwisec
2939
10.0 40.0 70.0 100.0 80.0
3040
]
3141

42+
@test dropdims(conv(Float32.(x), Float32.(w); pad=1), dims = (3,4)) == Float32.([
43+
1.0 9.0 29.0 49.0 48.0;
44+
4.0 29.0 79.0 129.0 115.0;
45+
7.0 39.0 89.0 139.0 122.0;
46+
10.0 49.0 99.0 149.0 129.0;
47+
13.0 59.0 109.0 159.0 136.0;
48+
10.0 40.0 70.0 100.0 80.0
49+
])
50+
3251
@test dropdims(conv(x, w; dilation=2), dims = (3,4)) == [
3352
48 98;
3453
58 108;
@@ -157,10 +176,16 @@ end
157176
1150.0 1330.0 1510.0]
158177
@test dropdims(conv(x, w), dims = (4,5)) == res
159178

179+
@test dropdims(conv(Float32.(x), Float32.(w)), dims = (4,5)) == Float32.(res)
180+
160181
@test dropdims(conv(x, w; stride=2), dims = (3,4,5)) == [
161182
322.0 682.0;
162183
394.0 754.0]
163184

185+
@test dropdims(conv(Float32.(x), Float32.(w); stride=2), dims = (3,4,5)) == Float32.([
186+
322.0 682.0;
187+
394.0 754.0])
188+
164189
res = zeros(6,5,4)
165190
res[:, :, 1] = [
166191
1.0 9.0 29.0 49.0 48.0;
@@ -192,12 +217,19 @@ end
192217
270.0 660.0 730.0 800.0 480.0]
193218
@test dropdims(conv(x, w; pad=1), dims = (4,5)) == res
194219

220+
@test dropdims(conv(Float32.(x), Float32.(w); pad=1), dims = (4,5)) == Float32.(res)
221+
195222
@test dropdims(conv(x, w; dilation=2), dims = (3,4,5)) == [
196223
608 788;
197224
644 824;
198225
680 860.
199226
]
200227

228+
@test dropdims(conv(Float32.(x), Float32.(w); dilation=2), dims = (3,4,5)) == Float32.([
229+
608 788;
230+
644 824;
231+
680 860.
232+
])
201233
# NaN tests for dilation forward pass
202234

203235
ys = []

0 commit comments

Comments
 (0)