Skip to content

Commit b919324

Browse files
authored
Merge pull request #57 from avik-pal/patch1
Fix Depthwise Convolutions
2 parents 7f90f38 + 2f2bea7 commit b919324

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

src/impl/conv.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ function im2col_dims(w::NTuple{4, Int}, y)
182182
end
183183

184184
function depthwiseconv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
185-
padding = 0, stride = 1, mode = 1, alpha = T(1)) where T
185+
padding = 0, stride = 1, mode = 0, alpha = T(1)) where T
186186
Wx,Hx,Cx,Nx = size(x)
187187
Ww,Hw,Cm,Cw = size(w) # Cm = Channel Multiplier
188188
@assert Cx == Cw DimensionMismatch()
@@ -223,7 +223,7 @@ function depthwiseconv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4},
223223
im2col2d!(dims_w, x, x2, i, p1, p2, s1, s2, mode)
224224
dwidx = 1
225225
@inbounds for j in 1:Cx
226-
gemm!('T','T',M,N,K,alpha,pointer(x2,(j-1)*M*K+1),pointer(dy,dyidx+(j-1)*K*N),beta,pointer(dw,dwidx))
226+
gemm!('T','N',M,N,K,alpha,pointer(x2,(j-1)*M*K+1),pointer(dy,dyidx+(j-1)*K*N),beta,pointer(dw,dwidx))
227227
dwidx += W
228228
end
229229
dyidx += Y
@@ -248,7 +248,7 @@ function depthwiseconv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4},
248248
dyidx = 1
249249
@inbounds for i in 1:Nx
250250
@inbounds for j in 1:Cx
251-
gemm!('N','T',M,N,K,alpha,pointer(dy,dyidx+(j-1)*K*M),pointer(w,(j-1)*K*N+1),beta,pointer(x2,(j-1)*M*N+1))
251+
gemm!('N','T',M,N,K,alpha,pointer(dy,dyidx+(j-1)*K*M),pointer(w,(j-1)*K*N+1),beta,pointer(x2,(j-1)*M*N+1))
252252
end
253253
col2im2d!(dims_w,dx,x2,i,p1,p2,s1,s2,mode)
254254
dyidx += Y

test/conv.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,25 @@ end
6969

7070
@testset "depthwiseconv2d" begin
7171
x = reshape(Float64[1:18;], 3, 3, 2, 1)
72-
w = reshape(Float64[1:8;], 2, 2, 1, 2)
72+
w = reshape(Float64[1:16;], 2, 2, 2, 2)
7373

74-
@test depthwiseconv(x, w)[:] == [37.0, 47.0, 67.0, 77.0, 319.0, 345.0, 397.0, 423.0]
74+
@test depthwiseconv(x, w)[:] == [23.0, 33.0, 53.0, 63.0, 71.0, 97.0, 149.0, 175.0, 497.0, 539.0, 623.0, 665.0, 689.0, 747.0, 863.0, 921.0]
7575

76-
@test depthwiseconv(x, w, stride = 2, pad = 1)[:] == [4.0, 18.0, 36.0, 77.0, 80.0, 173.0, 206.0, 423.0]
76+
@test depthwiseconv(x, w, stride = 2, pad = 1)[:] == [1.0, 7.0, 19.0, 63.0, 5.0, 27.0, 63.0, 175.0, 90.0, 218.0, 287.0, 665.0, 130.0, 310.0, 403.0, 921.0]
7777

78-
@test depthwiseconv(x, w, stride = 2)[:] == [37.0, 319.0]
78+
@test depthwiseconv(x, w, stride = 2)[:] == [23.0, 71.0, 497.0, 689.0]
7979

80-
@test depthwiseconv(x, w, pad = 1)[:] == [4.0, 11.0, 18.0, 9.0, 18.0, 37.0, 47.0, 21.0, 36.0, 67.0, 77.0, 33.0, 14.0, 23.0, 26.0, 9.0, 80.0, 158.0, 173.0, 84.0, 164.0, 319.0, 345.0, 165.0, 206.0, 397.0, 423.0, 201.0, 96.0, 182.0, 193.0, 90.0]
80+
@test depthwiseconv(x, w, pad = 1)[:] == [1.0, 4.0, 7.0, 6.0, 7.0, 23.0, 33.0, 24.0, 19.0, 53.0, 63.0, 42.0, 21.0, 52.0, 59.0, 36.0, 5.0, 16.0, 27.0, 18.0, 27.0, 71.0, 97.0, 60.0, 63.0, 149.0, 175.0, 102.0, 49.0, 112.0, 127.0, 72.0, 90.0, 199.0, 218.0, 120.0, 227.0, 497.0, 539.0, 294.0, 287.0, 623.0, 665.0, 360.0, 176.0, 379.0, 402.0, 216.0, 130.0, 283.0, 310.0, 168.0, 319.0, 689.0, 747.0, 402.0, 403.0, 863.0, 921.0, 492.0, 240.0, 511.0, 542.0, 288.0]
8181

8282
# the correctness of the gradients are being verified by calling
8383
# the corresponding counvolution gradients
8484

85-
dy = reshape(Float64[1:8;], 2,2,2,1)
85+
dy = reshape(Float64[1:16;], 2,2,4,1)
8686
local z = ∇depthwiseconv_data(dy,x,w)
8787
for i in 1:2
8888
X = copy(x[:,:,i:i,:]);
8989
W = copy(permutedims(w[:,:,:,i:i],[1,2,4,3]));
90-
DY = copy(dy[:,:,i:i,:]);
90+
DY = copy(dy[:,:,2i-1:2i,:]);
9191
res = ∇conv_data(DY,X,W)
9292
@test squeeze(z[:,:,i:i,:], (3,4)) == squeeze(res, (3,4))
9393
end
@@ -96,17 +96,17 @@ end
9696
for i in 1:2
9797
X = copy(x[:,:,i:i,:]);
9898
W = copy(permutedims(w[:,:,:,i:i],[1,2,4,3]))
99-
DY = copy(dy[:,:,i:i,:])
99+
DY = copy(dy[:,:,2i-1:2i,:])
100100
res = ∇conv_filter(DY,X,W)
101-
@test squeeze(z[:,:,:,i:i], (3,4)) == squeeze(res, (3,4))
101+
@test squeeze(z[:,:,:,i:i], (4)) == squeeze(res, (3))
102102
end
103103

104-
@test size(∇depthwiseconv_filter(rand(2,2,2,1), x, w)) == size(w)
105-
@test size(∇depthwiseconv_data(rand(2,2,2,1), x, w)) == size(x)
104+
@test size(∇depthwiseconv_filter(rand(2,2,4,1), x, w)) == size(w)
105+
@test size(∇depthwiseconv_data(rand(2,2,4,1), x, w)) == size(x)
106106

107107
# Test for the stride/pad for backward pass
108108
y = depthwiseconv(x,w,stride=2,pad=1)
109-
@test size(y) == (2,2,2,1)
109+
@test size(y) == (2,2,4,1)
110110
@test size(∇depthwiseconv_filter(rand(size(y)), x, w, stride=2, pad=1)) == size(w)
111111
@test size(∇depthwiseconv_data(rand(size(y)), x, w, stride=2, pad=1)) == size(x)
112112
end

0 commit comments

Comments
 (0)