Skip to content

Commit 42d9f64

Browse files
authored
Merge pull request #92 from FluxML/sf/typed_im2col
Rewrite `im2col()` for greater performance
2 parents ccc6dad + 3e56f3d commit 42d9f64

File tree

1 file changed

+142
-48
lines changed

1 file changed

+142
-48
lines changed

src/impl/conv.jl

Lines changed: 142 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,112 @@ function psize(p, x)
1010
end
1111
end
1212

13-
function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, width::Int, height::Int, channels::Int,
14-
kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, stride_w::Int, stride_h::Int,
15-
dil_w::Int, dil_h::Int, mode::Int) where T
16-
17-
height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1
18-
width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1
19-
channels_col = channels * kernel_h * kernel_w
13+
"""
    ConvDims{img, kernel, channels, stride, padding, dilation, flipkernel}

Type system-level description of a 2-d convolution. The struct has no fields; all the
information lives in the type parameters so that methods such as `im2col_2d!()` can be
specialized on it and generate efficient code.

The parameters, as they are consumed by the accessors below and by their callers:

- `img`:        `(width, height)` of the input image
- `kernel`:     `(width, height)` of the kernel
- `channels`:   number of input channels
- `stride`:     `(stride_w, stride_h)`
- `padding`:    `(pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi)`
- `dilation`:   `(dil_w, dil_h)`
- `flipkernel`: `Bool`, whether kernel indices are mirrored
"""
struct ConvDims{img, kernel, channels, stride, padding, dilation, flipkernel} end

# Accessors: each one simply returns the corresponding type parameter. Only the leading
# parameters that are actually needed are named; the trailing ones are left free.
img_size(::ConvDims{I}) where {I} = I
kernel_size(::ConvDims{I,K}) where {I, K} = K
img_channels(::ConvDims{I,K,C}) where {I, K, C} = C
stride(::ConvDims{I,K,C,S}) where {I, K, C, S} = S
padding(::ConvDims{I,K,C,S,P}) where {I, K, C, S, P} = P
dilation(::ConvDims{I,K,C,S,P,D}) where {I, K, C, S, P, D} = D
flipkernel(::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = F

"""
    output_size(c::ConvDims)

Calculate the output dimensions `(width, height)` of this convolution.

Each axis is computed from its own image extent, low/high padding, kernel extent,
dilation and stride:

    div(in + pad_lo + pad_hi - (kernel - 1) * dilation - 1, stride) + 1
"""
function output_size(::ConvDims{I,K,C,S,P,D}) where {I, K, C, S, P, D}
    out_w = div(I[1] + P[1] + P[2] - (K[1] - 1) * D[1] - 1, S[1]) + 1
    # The height axis must use the *second* entry of kernel, dilation and stride;
    # reusing the width entries gives wrong results for any non-square configuration.
    out_h = div(I[2] + P[3] + P[4] - (K[2] - 1) * D[2] - 1, S[2]) + 1
    return (out_w, out_h)
end
30+
31+
function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, cdims::ConvDims) where T
32+
width, height = img_size(cdims)
33+
kernel_w, kernel_h = kernel_size(cdims)
34+
channels = img_channels(cdims)
35+
pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi = padding(cdims)
36+
dil_w, dil_h = dilation(cdims)
37+
stride_w, stride_h = stride(cdims)
38+
width_col, height_col = output_size(cdims)
39+
40+
if flipkernel(cdims)
41+
flipk = (w, h) -> (kernel_w - w + 1, kernel_h - h + 1)
42+
else
43+
flipk = (w, h) -> (w, h)
44+
end
2045

21-
#pragma omp parallel for
22-
for c = 1:channels_col
23-
w_offset = (c - 1) % kernel_w
24-
h_offset = div(c - 1, kernel_w) % kernel_h
25-
c_im = div(c - 1, kernel_h * kernel_w)
26-
if mode == 0
27-
w_offset = kernel_w - 1 - w_offset
28-
h_offset = kernel_h - 1 - h_offset
46+
# Reshape col for easy access.
47+
col_reshaped = reshape(col, (width_col, height_col, kernel_w, kernel_h, channels))
48+
49+
# Let us first calculate the number of rows/columns within which we must zero out some
50+
# portion of the image patches we're copying over. Note the subtractions on the `_hi`
51+
# variants are due to us needing to account for padding that is completely ignored due
52+
# to stride/dilation/kernel size combinations.
53+
spill_w_lo = ceil(Int, pad_w_lo/stride_w)
54+
spill_w_hi = width_col - div(width + pad_w_lo - (kernel_w - 1)*dil_w, stride_w)
55+
spill_h_lo = ceil(Int, pad_h_lo/stride_h)
56+
spill_h_hi = height_col - div(height + pad_h_lo - (kernel_h - 1)*dil_h, stride_h)
57+
spill_w_hi_abs = width_col - spill_w_hi + 1
58+
spill_h_hi_abs = height_col - spill_h_hi + 1
59+
60+
# First, a helper function to project from output (w, h) to input (input_w, input_h)
61+
project(idx, stride, pad) = (idx - 1)*stride - pad + 1
62+
63+
# These are the regions we're going to have to run with cognizance of padding
64+
padded_regions = (
65+
(1:width_col, 1:spill_h_lo),
66+
(1:spill_w_lo, (spill_h_lo+1):(spill_h_hi_abs-1)),
67+
(spill_w_hi_abs:width_col, (spill_h_lo+1):(spill_h_hi_abs-1)),
68+
(1:width_col, spill_h_hi_abs:height_col),
69+
)
70+
71+
# We begin by copying the central region of the image which requires no padding at all.
72+
# Eliminating the branches of the fully generalized version below gives us a nice
73+
# speedup on the majority of the data.
74+
for c in 1:channels
75+
for kh in 1:kernel_h
76+
for kw in 1:kernel_w
77+
for h in (spill_h_lo+1):(height_col - spill_h_hi)
78+
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
79+
80+
@inbounds for w in (spill_w_lo+1):(width_col - spill_w_hi)
81+
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
82+
col_reshaped[w, h, flipk(kw, kh)..., c] = img[input_kw, input_kh, c]
83+
end
84+
end
85+
end
2986
end
30-
for h = 1:height_col
31-
for w = 1:width_col
32-
h_pad = (h - 1) * stride_h - pad_h + h_offset * dil_h
33-
w_pad = (w - 1) * stride_w - pad_w + w_offset * dil_w
34-
if h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width
35-
col[((c - 1)*height_col+h-1) * width_col + w] =
36-
img[(c_im * height + h_pad) * width + w_pad + 1]
37-
else
38-
col[((c - 1)*height_col+h - 1) * width_col + w] = 0
87+
end
88+
89+
# For each "padded region", we run the fully general version
90+
for (w_region, h_region) in padded_regions
91+
for c in 1:channels
92+
for kh in 1:kernel_h
93+
for kw in 1:kernel_w
94+
@inbounds for h in h_region
95+
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
96+
97+
# If this column is off the edge, then deal with the entire thing
98+
# in one fell swoop, like a ravenous flock of crows. CAW CAW.
99+
if input_kh <= 0 || input_kh > height
100+
for w in w_region
101+
col_reshaped[w, h, flipk(kw, kh)..., c] = zero(eltype(col_reshaped))
102+
end
103+
continue
104+
end
105+
106+
@inbounds for w in w_region
107+
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
108+
109+
# If this pixel is off the edge of the map, clear it out.
110+
if input_kw <= 0 || input_kw > width
111+
col_reshaped[w, h, flipk(kw, kh)..., c] = zero(eltype(col_reshaped))
112+
continue
113+
end
114+
115+
# Copy the data over
116+
col_reshaped[w, h, flipk(kw, kh)..., c] = img[input_kw, input_kh, c]
117+
end
118+
end
39119
end
40120
end
41121
end
@@ -256,26 +336,41 @@ function depthwiseconv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4},
256336
return dx
257337
end
258338

339+
"""
    conv2d!(y, x, w, cdims::ConvDims; alpha=T(1))

Compute a 2-d convolution of `x` with `w` into `y`, one batch element at a time, using
the geometry described by `cdims`.

For every index `n` along the 4th dimension of `x`, the slice is unrolled into the
scratch matrix `x2` by `im2col_2d!` and then multiplied with `w` through `gemm!`, which
is handed raw pointers into `x2`, `w` and `y`.

# Arguments
- `y`: output array; its 3rd dimension is taken as the number of output channels.
- `x`: input array; its 4th dimension is iterated as the batch.
- `w`: kernel array.
- `cdims`: convolution dimensions (kernel size, output size, input channels).

# Keywords
- `alpha`: scalar passed through to `gemm!`.

Returns `y`.
"""
function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4},
                 cdims::ConvDims; alpha=T(1)) where T
    Ww, Hw = kernel_size(cdims)
    Wy, Hy = output_size(cdims)
    Cx = img_channels(cdims)
    # Number of output channels. This must be dimension 3 of `y`, not dimension 4
    # (the batch): it is both the column count of the GEMM and part of the per-sample
    # stride into `y`.
    Cy = size(y, 3)

    M = Wy * Hy          # rows of the im2col matrix / of one output sample
    N = Cy               # columns of the kernel matrix / of one output sample
    K = Ww * Hw * Cx     # shared inner dimension
    Y = M * N            # number of elements of `y` occupied by one batch element

    x2 = similar(x, im2col_dims(w, y))
    @inbounds for n in 1:size(x, 4)
        im2col_2d!(view(x, :, :, :, n), x2, cdims)
        gemm!('N', 'N', M, N, K, alpha, pointer(x2), pointer(w), T(0),
              pointer(y, (n - 1) * Y + 1))
    end
    return y
end
354+
259355
"""
    conv2d!(y, x, w; padding=0, stride=1, dilation=1, mode=0, alpha=T(1))

Keyword-argument front end for the `ConvDims`-based `conv2d!` method.

Validates `mode` and the channel counts, expands `padding`, `stride` and `dilation`
through `psize`, packs everything into a `ConvDims` type and forwards to
`conv2d!(y, x, w, cdims; alpha=alpha)`.

# Throws
- `ArgumentError` if `mode` is neither 0 nor 1.
- `DimensionMismatch` if the 3rd dimension of `x` differs from the 3rd dimension of `w`.
"""
function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
                 padding=0, stride=1, dilation=1, mode=0, alpha=T(1)) where T
    (mode == 0 || mode == 1) ||
        throw(ArgumentError("conv2d only supports mode=0 or 1."))

    img_w, img_h, img_c, _ = size(x)
    ker_w, ker_h, ker_c, _ = size(w)

    # The number of channels in `x` must match the number of channels in each kernel
    # of `w`. If it doesn't, throw a DimensionMismatch().
    img_c == ker_c || throw(DimensionMismatch())

    pad_w, pad_h = psize(padding, x)
    str_w, str_h = psize(stride, x)
    dil_w, dil_h = psize(dilation, x)

    # Padding is symmetric here: the same amount is used on the low and high side of
    # each axis. `mode == 0` is what selects kernel flipping.
    cdims = ConvDims{(img_w, img_h), (ker_w, ker_h), img_c, (str_w, str_h),
                     (pad_w, pad_w, pad_h, pad_h), (dil_w, dil_h), mode == 0}()
    return conv2d!(y, x, w, cdims; alpha=alpha)
end
280375

281376
function conv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, dy::AbstractArray{T,4};
@@ -332,37 +427,37 @@ function im2col2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,
332427
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T
333428
Wx,Hx,Cx,Nx = size(x)
334429
Ww,Hw,C1,C2 = w
335-
xn = x[:, :, :, n]
336-
im2col_2d!(xn,x2,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1,1,mode)
430+
xn = view(x, :, :, :, n)
431+
cdims = ConvDims{(Wx,Hx),(Ww,Hw),Cx,(s1,s2),(p1,p1,p2,p2),(1,1), mode == 0}()
432+
im2col_2d!(xn,x2,cdims)
337433
return x2
338434
end
339435

340436
"""
    im2col2d!(w::AbstractArray, x, x2, n, p1, p2, s1, s2, d1, d2, mode)

Unroll batch element `n` of `x` into the column matrix `x2` and return `x2`.

The image and kernel extents are read from `size(x)` and `size(w)`; `p1`/`p2` are used
as symmetric padding, `s1`/`s2` as stride and `d1`/`d2` as dilation for the first and
second axis respectively. `mode == 0` enables kernel flipping. The actual work is done
by `im2col_2d!` on a view of `x`, so no copy of the slice is made.
"""
function im2col2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
                   n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int,
                   mode::Int) where T
    img_w, img_h, img_c, _ = size(x)
    ker_w, ker_h = size(w, 1), size(w, 2)

    cdims = ConvDims{(img_w, img_h), (ker_w, ker_h), img_c, (s1, s2),
                     (p1, p1, p2, p2), (d1, d2), mode == 0}()
    im2col_2d!(view(x, :, :, :, n), x2, cdims)
    return x2
end
348445

349446
"""
    col2im2d!(w::NTuple{4,Int}, x, x2, n, p1, p2, s1, s2, mode)

Fold the column matrix `x2` back into batch element `n` of `x` and return `x`.

`w` is the kernel size given as a tuple; only its first two entries are used. Dilation
is fixed to `(1, 1)`. `col2im_2d!` is handed a view of `x`, so no copy-back step is
performed here (presumably `col2im_2d!` writes its result into that view; confirm
against its definition).
"""
function col2im2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
                   n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T
    img_w, img_h, img_c, _ = size(x)
    ker_w, ker_h = w[1], w[2]

    slice = view(x, :, :, :, n)
    col2im_2d!(x2, slice, img_w, img_h, img_c, ker_w, ker_h, p1, p2, s1, s2, 1, 1, mode)
    return x
end
358454

359455
"""
    col2im2d!(w::AbstractArray, x, x2, n, p1, p2, s1, s2, d1, d2, mode)

Fold the column matrix `x2` back into batch element `n` of `x` and return `x`.

The kernel extents are read from the first two dimensions of `w`. `col2im_2d!` is
handed a view of `x`, so no copy-back step is performed here (presumably `col2im_2d!`
writes its result into that view; confirm against its definition).
"""
function col2im2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
                   n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int,
                   mode::Int) where T
    img_w, img_h, img_c, _ = size(x)
    ker_w, ker_h = size(w, 1), size(w, 2)

    slice = view(x, :, :, :, n)
    col2im_2d!(x2, slice, img_w, img_h, img_c, ker_w, ker_h, p1, p2, s1, s2, d1, d2, mode)
    return x
end
368463

@@ -445,7 +540,7 @@ function im2col3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArr
445540
s3::Int, d1::Int, d2::Int, d3::Int, mode::Int) where T
446541
Wx,Hx,Dx,Cx,Nx = size(x)
447542
Ww,Hw,Dw,C1,C2 = size(w)
448-
xn = x[:, :, :, :, n]
543+
xn = view(x, :, :, :, :, n)
449544
im2col_3d!(xn,x2,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode)
450545
return x2
451546
end
@@ -455,8 +550,7 @@ function col2im3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArr
455550
s3::Int, d1::Int, d2::Int, d3::Int, mode::Int) where T
456551
Wx,Hx,Dx,Cx,Nx = size(x)
457552
Ww,Hw,Dw,C1,C2 = size(w)
458-
xn = x[:, :, :, :, n]
553+
xn = view(x, :, :, :, :, n)
459554
col2im_3d!(x2,xn,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode)
460-
x[:, :, :, :, n] = xn
461555
return x
462556
end

0 commit comments

Comments
 (0)