Skip to content

Commit f0d7839

Browse files
authored
Merge pull request #42 from avik-pal/depthwiseconv
Depthwise Convolutions
2 parents 23155cf + ad10923 commit f0d7839

File tree

3 files changed

+178
-1
lines changed

3 files changed

+178
-1
lines changed

src/conv.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,35 @@ conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
8585
pad = 0, stride = 1, dilation = 1) where T =
8686
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
8787

88+
# Depthwise Conv
89+
90+
function dcdims(x::NTuple{4,Int}, w::NTuple{4,Int}, pad, stride)
91+
((x[1] + 2 * pad[1] - w[1])÷stride[1] + 1,(x[2] + 2 * pad[2] - w[2])÷stride[2] + 1,w[3]*w[4],x[4])
92+
end
93+
94+
function depthwiseconv(x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray
95+
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
96+
depthwiseconv!(similar(x, dcdims(size(x), size(w), pad_, stride_)), x, w, pad = pad_, stride = stride_)
97+
end
98+
99+
depthwiseconv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
100+
pad = 0, stride = 1) where T =
101+
depthwiseconv2d!(y, x, w, padding = pad, stride = stride)
102+
103+
∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
104+
∇depthwiseconv_data!(zeros(x), dy, x, w; pad = pad, stride = stride)
105+
106+
∇depthwiseconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
107+
∇depthwiseconv_filter!(zeros(w), dy, x, w; pad = pad, stride = stride)
108+
109+
∇depthwiseconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
110+
pad = 0, stride = 1) where T =
111+
depthwiseconv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride)
112+
113+
∇depthwiseconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
114+
pad = 0, stride = 1) where T =
115+
depthwiseconv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride)
116+
88117
# Pooling
89118

90119
function pdims(dims::Dims{N}, window, padding, stride) where N

src/impl/conv.jl

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,92 @@ function im2col_dims(w,y)
170170
return (r, c)
171171
end
172172

173+
function im2col_dims(w::NTuple{4, Int}, y)
174+
N = ndims(y)
175+
r,c = 1,1
176+
for i=1:N-2
177+
r *= size(y,i)
178+
c *= w[i]
179+
end
180+
c *= w[N-1]
181+
return (r, c)
182+
end
183+
184+
function depthwiseconv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
185+
padding = 0, stride = 1, mode = 1, alpha = T(1)) where T
186+
Wx,Hx,Cx,Nx = size(x)
187+
Ww,Hw,Cm,Cw = size(w) # Cm = Channel Multiplier
188+
@assert Cx == Cw DimensionMismatch()
189+
Wy,Hy,Cy,Ny = size(y) # Cy = Cw * Cm
190+
dims_w = (Ww,Hw,Cw,Cm*Cw)
191+
x2dims = im2col_dims(dims_w,y)
192+
x2 = similar(x, x2dims)
193+
(p1,p2) = psize(padding,x)
194+
(s1,s2) = psize(stride,x)
195+
M,N,K,Y = Wy*Hy,Cm,Ww*Hw,Wy*Hy*Cm
196+
yidx = 1
197+
@inbounds for i in 1:Nx
198+
im2col2d!(dims_w, x, x2, i, p1, p2, s1, s2, mode)
199+
@inbounds for j in 1:Cx
200+
gemm!('N','N',M,N,K,alpha,pointer(x2,(j-1)*M*K+1),pointer(w,(j-1)*K*N+1),T(0),pointer(y,yidx))
201+
yidx += Y
202+
end
203+
end
204+
return y
205+
end
206+
207+
function depthwiseconv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}, dy::AbstractArray{T,4};
208+
padding=0, stride=1, mode=0, alpha=1) where T
209+
Wx,Hx,Cx,Nx = size(x)
210+
Ww,Hw,Cm,Cw = size(w) # Cm = Channel Multiplier
211+
@assert Cx == Cw DimensionMismatch()
212+
Wy,Hy,Cy,Ny = size(dy) # Cy = Cw * Cm
213+
@assert Cy == Cw * Cm DimensionMismatch()
214+
dims_w = (Ww,Hw,Cw,Cm*Cw)
215+
x2dims = im2col_dims(dims_w,dy)
216+
x2 = similar(x, x2dims)
217+
(p1,p2) = psize(padding,x)
218+
(s1,s2) = psize(stride,x)
219+
M,N,K,Y,W = Ww*Hw,Cm,Wy*Hy,Wy*Hy*Cm*Cx,Ww*Hw*Cm
220+
alpha,beta = T(alpha),T(1)
221+
dyidx = 1
222+
@inbounds for i in 1:Nx
223+
im2col2d!(dims_w, x, x2, i, p1, p2, s1, s2, mode)
224+
dwidx = 1
225+
@inbounds for j in 1:Cx
226+
gemm!('T','T',M,N,K,alpha,pointer(x2,(j-1)*M*K+1),pointer(dy,dyidx+(j-1)*K*N),beta,pointer(dw,dwidx))
227+
dwidx += W
228+
end
229+
dyidx += Y
230+
end
231+
return dw
232+
end
233+
234+
function depthwiseconv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}, dy::AbstractArray{T,4};
235+
padding=0, stride=1, mode=0, alpha=1) where T
236+
Wx,Hx,Cx,Nx = size(x)
237+
Ww,Hw,Cm,Cw = size(w) # Cm = Channel Multiplier
238+
@assert Cx == Cw DimensionMismatch()
239+
Wy,Hy,Cy,Ny = size(dy) # Cy = Cw * Cm
240+
@assert Cy == Cw * Cm DimensionMismatch()
241+
dims_w = (Ww,Hw,Cw,Cm*Cw)
242+
x2dims = im2col_dims(dims_w,dy)
243+
x2 = similar(x, x2dims)
244+
M,N,K,Y,W = Wy*Hy,Ww*Hw,Cm,Wy*Hy*Cm*Cx,Ww*Hw*Cm
245+
alpha,beta = T(alpha),T(0)
246+
(p1,p2) = psize(padding,x)
247+
(s1,s2) = psize(stride,x)
248+
dyidx = 1
249+
@inbounds for i in 1:Nx
250+
@inbounds for j in 1:Cx
251+
gemm!('N','T',M,N,K,alpha,pointer(dy,dyidx+(j-1)*K*M),pointer(w,(j-1)*K*N+1),beta,pointer(x2,(j-1)*M*N+1))
252+
end
253+
col2im2d!(dims_w,dx,x2,i,p1,p2,s1,s2,mode)
254+
dyidx += Y
255+
end
256+
return dx
257+
end
258+
173259
function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
174260
padding=0, stride=1, dilation=1, mode=0, alpha=T(1)) where T
175261
if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end
@@ -242,6 +328,15 @@ function conv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4}, w::Abstra
242328
return dx
243329
end
244330

331+
function im2col2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
332+
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T
333+
Wx,Hx,Cx,Nx = size(x)
334+
Ww,Hw,C1,C2 = w
335+
xn = x[:, :, :, n]
336+
im2col_2d!(xn,x2,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1,1,mode)
337+
return x2
338+
end
339+
245340
function im2col2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
246341
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int, mode::Int) where T
247342
Wx,Hx,Cx,Nx = size(x)
@@ -251,6 +346,16 @@ function im2col2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArr
251346
return x2
252347
end
253348

349+
function col2im2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
350+
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T
351+
Wx,Hx,Cx,Nx = size(x)
352+
Ww,Hw,C1,C2 = w
353+
xn = x[:, :, :, n]
354+
col2im_2d!(x2,xn,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1,1,mode)
355+
x[:, :, :, n] = xn
356+
return x
357+
end
358+
254359
function col2im2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
255360
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int, mode::Int) where T
256361
Wx,Hx,Cx,Nx = size(x)

test/conv.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool
1+
using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool, depthwiseconv, ∇depthwiseconv_filter, ∇depthwiseconv_data
22

33
@testset "conv2d" begin
44
x = reshape(Float64[1:20;], 5, 4, 1, 1)
@@ -67,6 +67,49 @@ using NNlib: conv, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool
6767

6868
end
6969

70+
@testset "depthwiseconv2d" begin
71+
x = reshape(Float64[1:18;], 3, 3, 2, 1)
72+
w = reshape(Float64[1:8;], 2, 2, 1, 2)
73+
74+
@test depthwiseconv(x, w)[:] == [37.0, 47.0, 67.0, 77.0, 319.0, 345.0, 397.0, 423.0]
75+
76+
@test depthwiseconv(x, w, stride = 2, pad = 1)[:] == [4.0, 18.0, 36.0, 77.0, 80.0, 173.0, 206.0, 423.0]
77+
78+
@test depthwiseconv(x, w, stride = 2)[:] == [37.0, 319.0]
79+
80+
@test depthwiseconv(x, w, pad = 1)[:] == [4.0, 11.0, 18.0, 9.0, 18.0, 37.0, 47.0, 21.0, 36.0, 67.0, 77.0, 33.0, 14.0, 23.0, 26.0, 9.0, 80.0, 158.0, 173.0, 84.0, 164.0, 319.0, 345.0, 165.0, 206.0, 397.0, 423.0, 201.0, 96.0, 182.0, 193.0, 90.0]
81+
82+
# the correctness of the gradients are being verified by calling
83+
# the corresponding counvolution gradients
84+
85+
dy = reshape(Float64[1:8;], 2,2,2,1)
86+
local z = ∇depthwiseconv_data(dy,x,w)
87+
for i in 1:2
88+
X = copy(x[:,:,i:i,:]);
89+
W = copy(permutedims(w[:,:,:,i:i],[1,2,4,3]));
90+
DY = copy(dy[:,:,i:i,:]);
91+
res = ∇conv_data(DY,X,W)
92+
@test squeeze(z[:,:,i:i,:], (3,4)) == squeeze(res, (3,4))
93+
end
94+
95+
z = ∇depthwiseconv_filter(dy, x, w)
96+
for i in 1:2
97+
X = copy(x[:,:,i:i,:]);
98+
W = copy(permutedims(w[:,:,:,i:i],[1,2,4,3]))
99+
DY = copy(dy[:,:,i:i,:])
100+
res = ∇conv_filter(DY,X,W)
101+
@test squeeze(z[:,:,:,i:i], (3,4)) == squeeze(res, (3,4))
102+
end
103+
104+
@test size(∇depthwiseconv_filter(rand(2,2,2,1), x, w)) == size(w)
105+
@test size(∇depthwiseconv_data(rand(2,2,2,1), x, w)) == size(x)
106+
107+
# Test for the stride/pad for backward pass
108+
y = depthwiseconv(x,w,stride=2,pad=1)
109+
@test size(y) == (2,2,2,1)
110+
@test size(∇depthwiseconv_filter(rand(size(y)), x, w, stride=2, pad=1)) == size(w)
111+
@test size(∇depthwiseconv_data(rand(size(y)), x, w, stride=2, pad=1)) == size(x)
112+
end
70113

71114
@testset "maxpool2d" begin
72115

0 commit comments

Comments
 (0)