Skip to content

Commit 8546f3c

Browse files
authored
Merge pull request #54 from tejank10/conv_transpose
2D Transpose Convolutions
2 parents 1940455 + 0d27d79 commit 8546f3c

File tree

3 files changed

+116
-62
lines changed

3 files changed

+116
-62
lines changed

src/conv.jl

Lines changed: 73 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,74 @@ function cdims(x::NTuple{N}, w::NTuple{N}, pad, stride) where N
1717
end
1818
end
1919

20+
21+
# Conv Transpose dims
22+
23+
"""
    ctdims(x, w, pad, stride, dilation)

Return the size tuple of a transposed-convolution result, given the input size `x`
and the kernel size `w` (both `N`-tuples).

For every leading dimension `i < N-1` the extent is
`(x[i] - 1) * stride[i] + dilation[i] * (w[i] - 1) - 2 * pad[i] + 1`.
Dimension `N-1` is copied from `w[N-1]` and dimension `N` from `x[N]`
(presumably the channel and batch dimensions; confirm against callers).
"""
function ctdims(x::NTuple{N}, w::NTuple{N}, pad, stride, dilation) where N
  dims = ntuple(Val(N)) do d
    # The two trailing dimensions are copied over, not computed.
    d == N && return x[N]
    d == N-1 && return w[N-1]
    return stride[d] * (x[d] - 1) + (w[d] - 1) * dilation[d] - 2 * pad[d] + 1
  end
  return dims
end
34+
35+
36+
# Kernel dims
37+
38+
"""
    wdims(x, y, pad, stride, dilation)

Return the kernel size tuple implied by an input of size `x` and an output of
size `y` (both `N`-tuples).

For every leading dimension `i < N-1` the extent is
`1 + div((1 - y[i]) * stride[i] + x[i] + 2 * pad[i] - 1, dilation[i])`.
Dimension `N-1` is copied from `x[N-1]` and dimension `N` from `y[N-1]`.
"""
function wdims(x::NTuple{N}, y::NTuple{N}, pad, stride, dilation) where N
  dims = ntuple(Val(N)) do d
    # The last kernel dimension comes from the second-to-last dimension of `y`.
    d == N && return y[N-1]
    d == N-1 && return x[N-1]
    span = x[d] + 2 * pad[d] - 1 - (y[d] - 1) * stride[d]
    return div(span, dilation[d]) + 1
  end
  return dims
end
49+
2050
# Interface
2151

2252
head(x) = reverse(Base.tail(reverse(x)))
2353
padtuple(x::Tuple,p::Integer) = map(_->p, head(head(x)))
2454
padtuple(x::Tuple,p::Tuple) = p
2555
padtuple(x::AbstractArray,p) = padtuple(size(x),p)
2656

27-
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
57+
function conv(x::A, w::A; size=nothing, pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
2858
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
29-
conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
30-
x, w, pad = pad_, stride = stride_, dilation = dilation)
59+
if size === nothing
60+
size = cdims(Base.size(x), dilation_dims(w, dilation), pad_, stride_)
61+
end
62+
conv!(similar(x, size), x, w, pad = pad_, stride = stride_, dilation = dilation)
3163
end
3264

33-
function crosscor(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
65+
function crosscor(x::A, w::A; size=nothing, pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
3466
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
35-
crosscor!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)),
36-
x, w, pad = pad_, stride = stride_, dilation = dilation)
67+
if size === nothing
68+
size = cdims(Base.size(x), dilation_dims(w, dilation), pad_, stride_)
69+
end
70+
crosscor!(similar(x, size), x, w, pad = pad_, stride = stride_, dilation = dilation)
3771
end
3872

39-
∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, flipkernel = 0) where A<:AbstractArray =
40-
∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
73+
"""
    ∇conv_data(dy, w; size=nothing, pad=0, stride=1, dilation=1, flipkernel=0)

Allocate a result array with `similar(dy, size)` and fill it by calling
`∇conv_data!` with `dy` and `w`.

# Keywords
- `size`: size of the returned array. When `nothing`, it is computed by `ctdims`
  from `size(dy)`, `size(w)` and the normalized padding/stride/dilation.
- `pad`, `stride`, `dilation`: each is normalized through `padtuple(dy, …)` before
  being forwarded.
- `flipkernel`: forwarded unchanged to `∇conv_data!`.
"""
function ∇conv_data(dy::A, w::A; size=nothing, pad = 0, stride = 1, dilation = 1, flipkernel = 0) where A<:AbstractArray
  pad_ = padtuple(dy, pad)
  stride_ = padtuple(dy, stride)
  dilation_ = padtuple(dy, dilation)
  # The `size` keyword shadows `Base.size` here, hence the qualified calls.
  dxsize = size === nothing ?
    ctdims(Base.size(dy), Base.size(w), pad_, stride_, dilation_) : size
  dx = similar(dy, dxsize)
  return ∇conv_data!(dx, dy, w, pad = pad_, stride = stride_, dilation = dilation_, flipkernel=flipkernel)
end
4180

42-
∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, flipkernel=0) where A<:AbstractArray =
43-
∇conv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
81+
"""
    ∇conv_filter(dy, x; size=nothing, pad=0, stride=1, dilation=1, flipkernel=0)

Allocate a zero-filled result array and fill it by calling `∇conv_filter!` with
`dy` and `x`.

# Keywords
- `size`: size of the returned array. When `nothing`, it is computed by `wdims`
  from `size(x)`, `size(dy)` and the normalized padding/stride/dilation.
- `pad`, `stride`, `dilation`: each is normalized through `padtuple(dy, …)`. The
  normalized tuples are forwarded to `∇conv_filter!`, matching what `conv`,
  `crosscor` and `∇conv_data` do.
- `flipkernel`: forwarded unchanged to `∇conv_filter!`.
"""
function ∇conv_filter(dy::A, x::A; size = nothing, pad = 0, stride = 1, dilation = 1, flipkernel=0) where A<:AbstractArray
  pad_, stride_, dilation_ = padtuple(dy, pad), padtuple(dy, stride), padtuple(dy, dilation)
  # The `size` keyword shadows `Base.size` here, hence the qualified calls.
  if size === nothing
    size = wdims(Base.size(x), Base.size(dy), pad_, stride_, dilation_)
  end
  # Zero the buffer in place rather than allocating a second array via `zero(...)`.
  dw = fill!(similar(dy, size), zero(eltype(dy)))
  return ∇conv_filter!(dw, dy, x; pad = pad_, stride = stride_, dilation = dilation_, flipkernel=flipkernel)
end
4488

4589
# N-D dispatch
4690

@@ -56,18 +100,16 @@ function crosscor!(y::AbstractArray, x::AbstractArray, w::AbstractArray;
56100
conv!(y, x, w, pad=pad, stride=stride, dilation=dilation, flipkernel=1)
57101
end
58102

59-
function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
60-
x::AbstractArray{T,3}, w::AbstractArray{T,3};
103+
function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3}, x::AbstractArray{T,3};
61104
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T
62-
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dw, dy, x, w))
105+
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dw, dy, x))
63106
∇conv_filter!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), flipkernel=flipkernel)
64107
return dw
65108
end
66109

67-
function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
68-
x::AbstractArray{T,3}, w::AbstractArray{T,3};
69-
pad = 0, stride = 1, dilation = 1, flipkernel = 0) where T
70-
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dx, dy, x, w))
110+
# 3-D arrays: reshape every argument to 4-D by inserting a singleton second
# dimension, append a trailing entry to pad/stride/dilation, and delegate to
# the 4-D `∇conv_data!` method. Returns the original 3-D `dx`.
function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3}, w::AbstractArray{T,3};
                     pad = 0, stride = 1, dilation = 1, flipkernel = 0) where T
  lift(a) = reshape(a, size(a,1), 1, size(a,2), size(a,3))
  ∇conv_data!(lift(dx), lift(dy), lift(w),
              pad = (pad...,0), stride = (stride...,1), dilation = (dilation..., 1),
              flipkernel = flipkernel)
  return dx
end
@@ -76,25 +118,25 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
76118
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
77119
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
78120

79-
∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
121+
∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4};
80122
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
81-
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
123+
conv2d_grad_w!(dw, x, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
82124

83-
∇conv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
125+
∇conv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, w::AbstractArray{T,4};
84126
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
85-
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
127+
conv2d_grad_x!(dx, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
86128

87129
conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
88130
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
89131
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
90132

91-
∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
133+
∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5};
92134
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
93-
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
135+
conv3d_grad_w!(dw, x, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
94136

95-
∇conv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
137+
∇conv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, w::AbstractArray{T,5};
96138
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
97-
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
139+
conv3d_grad_x!(dx, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
98140

99141
# Depthwise Conv
100142

@@ -216,3 +258,9 @@ meanpool_cpu!(y::AbstractArray{<:Real,5}, x::AbstractArray{<:Real,5}, k::Dims{3}
216258
k::Dims{3}; pad = (0,0), stride = k) =
217259
meanpool3d_grad!(dx, dy, y, x,
218260
window = k, padding = pad, stride = stride)
261+
262+
# Deprecated
263+
264+
# 0.4.2
265+
@deprecate ∇conv_data(dy::A, x::A, w::A; kw...) where A<:AbstractArray ∇conv_data(dy, w; size=size(x), kw...)
266+
@deprecate ∇conv_filter(dy::A, x::A, w::A; kw...) where A<:AbstractArray ∇conv_filter(dy, x; size=size(w), kw...)

src/impl/conv.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -278,15 +278,15 @@ function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{
278278
return y
279279
end
280280

281-
function conv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}, dy::AbstractArray{T,4};
281+
function conv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, dy::AbstractArray{T,4};
282282
padding=0, stride=1, dilation=1, mode=0, alpha=1) where T
283283
# dw = x'*dy
284284
Wx,Hx,Cx,Nx = size(x)
285-
Ww,Hw,C1,C2 = size(w)
285+
Ww,Hw,C1,C2 = size(dw)
286286
Wy,Hy,Cy,Ny = size(dy)
287287
# if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
288-
# @assert Cx==C1 && Cy==C2 && Ny==Nx
289-
x2dims = im2col_dims(w,dy)
288+
@assert Cx==C1 && Cy==C2 && Ny==Nx
289+
x2dims = im2col_dims(dw,dy)
290290
x2 = similar(x, x2dims)
291291
# op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix.
292292
Y,M,N,K = Wy*Hy*Cy,Ww*Hw*Cx,Cy,Wy*Hy
@@ -296,29 +296,29 @@ function conv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, w::Abstra
296296
(d1,d2) = psize(dilation,x)
297297
dyi = 1
298298
@inbounds for n in 1:Nx
299-
im2col2d!(w, x, x2, n, p1, p2, s1, s2, d1, d2, mode)
299+
im2col2d!(dw, x, x2, n, p1, p2, s1, s2, d1, d2, mode)
300300
gemm!('T','N',M,N,K,alpha,pointer(x2),pointer(dy,dyi),beta,pointer(dw))
301301
dyi += Y
302302
end
303303
return dw
304304
end
305305

306-
function conv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}, dy::AbstractArray{T,4};
306+
function conv2d_grad_x!(dx::AbstractArray{T,4}, w::AbstractArray{T,4}, dy::AbstractArray{T,4};
307307
padding=0, stride=1, dilation=1, mode=0, alpha=1) where T
308308
# dx = dy*w'
309-
Wx,Hx,Cx,Nx = size(x)
309+
Wx,Hx,Cx,Nx = size(dx)
310310
Ww,Hw,C1,C2 = size(w)
311311
Wy,Hy,Cy,Ny = size(dy)
312312
# if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
313313
@assert Cx==C1 && Cy==C2 && Ny==Nx
314314
x2dims = im2col_dims(w,dy)
315-
x2 = similar(x, x2dims)
315+
x2 = similar(dx, x2dims)
316316
# op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix.
317317
Y,M,N,K = Wy*Hy*Cy,Wy*Hy,Ww*Hw*Cx,Cy
318318
alpha,beta = T(alpha),T(0)
319-
(p1,p2) = psize(padding,x)
320-
(s1,s2) = psize(stride,x)
321-
(d1,d2) = psize(dilation,x)
319+
(p1,p2) = psize(padding,dx)
320+
(s1,s2) = psize(stride,dx)
321+
(d1,d2) = psize(dilation,dx)
322322
dyi = 1
323323
@inbounds for n in 1:Nx
324324
gemm!('N','T',M,N,K,alpha,pointer(dy,dyi),pointer(w),beta,pointer(x2))
@@ -352,7 +352,7 @@ function col2im2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,
352352
Ww,Hw,C1,C2 = w
353353
xn = x[:, :, :, n]
354354
col2im_2d!(x2,xn,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1,1,mode)
355-
x[:, :, :, n] = xn
355+
x[:, :, :, n] .= xn
356356
return x
357357
end
358358

@@ -362,7 +362,7 @@ function col2im2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArr
362362
Ww,Hw,C1,C2 = size(w)
363363
xn = x[:, :, :, n]
364364
col2im_2d!(x2,xn,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,d1,d2,mode)
365-
x[:, :, :, n] = xn
365+
x[:, :, :, n] .= xn
366366
return x
367367
end
368368

@@ -390,15 +390,15 @@ function conv3d!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{
390390
return y
391391
end
392392

393-
function conv3d_grad_w!(dw::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}, dy::AbstractArray{T,5};
393+
function conv3d_grad_w!(dw::AbstractArray{T,5}, x::AbstractArray{T,5}, dy::AbstractArray{T,5};
394394
padding=0, stride=1, dilation = 1, mode=0, alpha=1) where T
395395
# dw = x'*dy
396396
Wx,Hx,Dx,Cx,Nx = size(x)
397-
Ww,Hw,Dw,C1,C2 = size(w)
397+
Ww,Hw,Dw,C1,C2 = size(dw)
398398
Wy,Hy,Dy,Cy,Ny = size(dy)
399399
# if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
400-
# @assert Cx==C1 && Cy==C2 && Ny==Nx
401-
x2dims = im2col_dims(w,dy)
400+
@assert Cx==C1 && Cy==C2 && Ny==Nx
401+
x2dims = im2col_dims(dw,dy)
402402
x2 = similar(x, x2dims)
403403
# op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix.
404404
Y,M,N,K = Wy*Hy*Dy*Cy,Ww*Hw*Dw*Cx,Cy,Wy*Hy*Dy
@@ -408,29 +408,29 @@ function conv3d_grad_w!(dw::AbstractArray{T,5}, x::AbstractArray{T,5}, w::Abstra
408408
(d1,d2,d3) = psize(dilation,x)
409409
dyi = 1
410410
@inbounds for n in 1:Nx
411-
im2col3d!(w, x, x2, n, p1, p2, p3, s1, s2, s3, d1, d2, d3, mode)
411+
im2col3d!(dw, x, x2, n, p1, p2, p3, s1, s2, s3, d1, d2, d3, mode)
412412
gemm!('T','N',M,N,K,alpha,pointer(x2),pointer(dy,dyi),beta,pointer(dw))
413413
dyi += Y
414414
end
415415
return dw
416416
end
417417

418-
function conv3d_grad_x!(dx::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5}, dy::AbstractArray{T,5};
418+
function conv3d_grad_x!(dx::AbstractArray{T,5}, w::AbstractArray{T,5}, dy::AbstractArray{T,5};
419419
padding=0, stride=1, dilation = 1, mode=0, alpha=1) where T
420420
# dx = dy*w'
421-
Wx,Hx,Dx,Cx,Nx = size(x)
421+
Wx,Hx,Dx,Cx,Nx = size(dx)
422422
Ww,Hw,Dw,C1,C2 = size(w)
423423
Wy,Hy,Dy,Cy,Ny = size(dy)
424424
# if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
425425
@assert Cx==C1 && Cy==C2 && Ny==Nx
426426
x2dims = im2col_dims(w,dy)
427-
x2 = similar(x, x2dims)
427+
x2 = similar(dx, x2dims)
428428
# op(A) is an m-by-k matrix, op(B) is a k-by-n matrix, C is an m-by-n matrix.
429429
Y,M,N,K = Wy*Hy*Dy*Cy,Wy*Hy*Dy,Ww*Hw*Dw*Cx,Cy
430430
alpha,beta = T(alpha),T(0)
431-
(p1,p2,p3) = psize(padding,x)
432-
(s1,s2,s3) = psize(stride,x)
433-
(d1,d2,d3) = psize(dilation,x)
431+
(p1,p2,p3) = psize(padding,dx)
432+
(s1,s2,s3) = psize(stride,dx)
433+
(d1,d2,d3) = psize(dilation,dx)
434434
dyi = 1
435435
@inbounds for n in 1:Nx
436436
gemm!('N','T',M,N,K,alpha,pointer(dy,dyi),pointer(w),beta,pointer(x2))

test/conv.jl

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwiseconv, ∇depthwiseconv_filter, ∇depthwiseconv_data
1+
using NNlib: conv, crosscor, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwiseconv, ∇depthwiseconv_filter, ∇depthwiseconv_data
22

33
@testset "conv2d" begin
44
x = reshape(Float64[1:20;], 5, 4, 1, 1)
@@ -10,6 +10,12 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwisec
1010
49 99 149;
1111
59 109 159.]
1212

13+
@test dropdims(crosscor(x, w), dims = (3,4)) == [
14+
51 101 151;
15+
61 111 161;
16+
71 121 171;
17+
81 131 181.]
18+
1319
@test dropdims(conv(Float32.(x), Float32.(w)), dims=(3,4)) == Float32.([
1420
29 79 129;
1521
39 89 139;
@@ -59,26 +65,26 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwisec
5965
# correctness of gradients is cross-checked with CUDNN.jl
6066
# (it's assumed convolution code won't change often)
6167

62-
@test size(∇conv_filter(reshape(rand(4,3), 4, 3, 1, 1), x, w)) == size(w)
63-
@test size(∇conv_data(reshape(rand(4,3), 4, 3, 1, 1), x, w)) == size(x)
68+
@test size(∇conv_filter(reshape(rand(4,3), 4, 3, 1, 1), x)) == size(w)
69+
@test size(∇conv_data(reshape(rand(4,3), 4, 3, 1, 1), w)) == size(x)
6470

6571
# Test that stride/pad work backward as well
6672
y = conv(x, w; stride=2, pad=1, dilation=2)
6773
@test size(y) == (3, 2, 1, 1)
68-
@test size(∇conv_filter(y, x, w; stride=2, pad=1, dilation=2)) == size(w)
69-
@test size(∇conv_data(y, x, w; stride=2, pad=1, dilation=2)) == size(x)
74+
@test size(∇conv_filter(y, x; size=size(w), stride=2, pad=1, dilation=2)) == size(w)
75+
@test size(∇conv_data(y, w; size=size(x), stride=2, pad=1, dilation=2)) == size(x)
7076

7177
# NaN tests for dilation backward pass: filters
7278
dy = randn(size(ys[1]))
7379
dws = []
7480
for idx in 1:1000
75-
push!(dws, ∇conv_filter(dy, x, w; dilation=2))
81+
push!(dws, ∇conv_filter(dy, x; size=size(w), dilation=2))
7682
end
7783

7884
# NaN tests for dilation backward pass: input
7985
dxs = []
8086
for idx in 1:1000
81-
push!(dxs, ∇conv_data(dy, x, w; dilation=2))
87+
push!(dxs, ∇conv_data(dy, w; size=size(x), dilation=2))
8288
end
8389

8490
@test !any([any(isnan.(dws[idx])) for idx in 1:1000])
@@ -107,7 +113,7 @@ end
107113
X = copy(x[:,:,i:i,:]);
108114
W = copy(permutedims(w[:,:,:,i:i],[1,2,4,3]));
109115
DY = copy(dy[:,:,2i-1:2i,:]);
110-
res = ∇conv_data(DY,X,W)
116+
res = ∇conv_data(DY,W;size=size(X))
111117
@test dropdims(z[:,:,i:i,:], dims=(3,4)) == dropdims(res, dims=(3,4))
112118
end
113119

@@ -116,7 +122,7 @@ end
116122
X = copy(x[:,:,i:i,:]);
117123
W = copy(permutedims(w[:,:,:,i:i],[1,2,4,3]))
118124
DY = copy(dy[:,:,2i-1:2i,:])
119-
res = ∇conv_filter(DY,X,W)
125+
res = ∇conv_filter(DY,X; size=size(W))
120126
@test dropdims(z[:,:,:,i:i]; dims=(4)) == dropdims(res; dims=(3))
121127
end
122128

@@ -236,20 +242,20 @@ end
236242
# correctness of gradients is cross-checked with CUDNN.jl
237243
# (it's assumed convolution code won't change often)
238244

239-
@test size(∇conv_filter(reshape(rand(4,3,2), 4, 3, 2, 1, 1), x, w)) == size(w)
240-
@test size(∇conv_data(reshape(rand(4,3,2), 4, 3, 2, 1, 1), x, w)) == size(x)
245+
@test size(∇conv_filter(reshape(rand(4,3,2), 4, 3, 2, 1, 1), x; size=size(w))) == size(w)
246+
@test size(∇conv_data(reshape(rand(4,3,2), 4, 3, 2, 1, 1), w; size=size(x))) == size(x)
241247

242248
# NaN tests for dilation backward pass: filters
243249
dy = randn(size(ys[1]))
244250
dws = []
245251
for idx in 1:1000
246-
push!(dws, ∇conv_filter(dy, x, w; dilation=2))
252+
push!(dws, ∇conv_filter(dy, x; size=size(w), dilation=2))
247253
end
248254

249255
# NaN tests for dilation backward pass: input
250256
dxs = []
251257
for idx in 1:1000
252-
push!(dxs, ∇conv_data(dy, x, w; dilation=2))
258+
push!(dxs, ∇conv_data(dy, w; size=size(x), dilation=2))
253259
end
254260

255261
@test !any([any(isnan.(dws[idx])) for idx in 1:1000])

0 commit comments

Comments
 (0)