Skip to content

Commit 2495c1b

Browse files
author
Avik Pal
committed
Tests pass
1 parent fba6d4a commit 2495c1b

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

src/nnpack/interface.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,11 @@ conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:A
5050
conv(Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
5151

5252
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
53-
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride)
53+
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
5454
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
5555
if fallback
5656
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
5757
else
58-
@warn "Accessed"
5958
conv!(y, x, w, zeros(Float32, size(y, 3)), pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
6059
end
6160
end
@@ -64,12 +63,11 @@ conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) w
6463
conv(Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
6564

6665
function conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
67-
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride)
66+
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
6867
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
6968
if fallback
7069
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
7170
else
72-
@warn "Accessed 2.0"
7371
conv!(y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
7472
end
7573
end
@@ -78,7 +76,7 @@ crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0
7876
crosscor(Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
7977

8078
function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
81-
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride)
79+
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
8280
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
8381
if fallback
8482
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode = 1)
@@ -88,10 +86,12 @@ function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo =
8886
end
8987

9088
conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
91-
conv(Float32.(y), Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
89+
conv!(Float32.(y), Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
9290

9391
function conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
94-
flipkernel == 0 && (w = reverse(reverse(w, dims=1), dims=2))
92+
if flipkernel == 0
93+
w = reverse(reverse(w, dims=1), dims=2)
94+
end
9595
nnp_convolution_output(y, x, w, b, algo = algo, padding = pad, stride = stride, threadpool = shared_threadpool[])
9696
end
9797

@@ -105,9 +105,9 @@ crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo =
105105
∇conv_data(Float32.(dy), Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
106106

107107
function ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
108-
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride)
108+
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
109109
if fallback
110-
error("Unsupported Operation")
110+
conv2d_grad_x!(zeros(Float32, size(x)), x, w, dy, padding = pad_, stride = stride_, dilation = dilation)
111111
else
112112
∇conv_data!(zeros(Float32, size(x)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
113113
end
@@ -125,9 +125,9 @@ end
125125
∇conv_filter(Float32.(dy), Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
126126

127127
function ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
128-
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride)
128+
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
129129
if fallback
130-
error("Unsupported Operation")
130+
conv2d_grad_w!(zeros(Float32, size(w)), x, w, dy, padding = pad_, stride = stride_, dilation = dilation)
131131
else
132132
∇conv_filter!(zeros(Float32, size(w)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
133133
end

test/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwisec
5757
@test size(∇conv_data(reshape(rand(4,3), 4, 3, 1, 1), x, w)) == size(x)
5858

5959
# Test that stride/pad work backward as well
60-
y = conv(x, w; stride=2, pad=1, dilation=2)
60+
y = Float64.(conv(x, w; stride=2, pad=1, dilation=2))
6161
@test size(y) == (3, 2, 1, 1)
6262
@test size(∇conv_filter(y, x, w; stride=2, pad=1, dilation=2)) == size(w)
6363
@test size(∇conv_data(y, x, w; stride=2, pad=1, dilation=2)) == size(x)

0 commit comments

Comments
 (0)