Skip to content

Commit 3e56f3d

Browse files
committed
Rewrite im2col() for greater performance
The performance gain of this PR comes mainly from identifying that the conditionals within the `im2col()` inner loops are really only needed towards the edges of the image being `im2col()`'ed. By separating the loops into two separate segments, we are able to achieve very high memory bandwidth utilization for the majority of the wallclock time of the function. Eventually we will want to parallelize this loop via threading, but until that is possible, this should be a nice improvement. In my testing, this speeds up `im2col()` by a factor of 2-3x. This change also begins the work of pushing some convolutional parameters into the type domain. This is especially helpful for `im2col()`, as it makes it possible for the compiler to completely elide the second set of for loops in the case of no padding (this is not tested; I don't know whether the compiler is actually able to elide the loops). Regardless, pushing these parameters into the type domain gives a substantial speed boost.
1 parent ccc6dad commit 3e56f3d

File tree

1 file changed

+142
-48
lines changed

1 file changed

+142
-48
lines changed

src/impl/conv.jl

Lines changed: 142 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,112 @@ function psize(p, x)
1010
end
1111
end
1212

13-
function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, width::Int, height::Int, channels::Int,
14-
kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, stride_w::Int, stride_h::Int,
15-
dil_w::Int, dil_h::Int, mode::Int) where T
16-
17-
height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1
18-
width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1
19-
channels_col = channels * kernel_h * kernel_w
13+
"""
    ConvDims{img, kernel, channels, stride, padding, dilation, flipkernel}

Type system-level information about convolution dimensions. Critical for things like
`im2col_2d!()` to generate efficient code, because every parameter is a compile-time
constant of the type rather than a runtime field.

As used by the constructors in this file, the type parameters are:

- `img`:        `(width, height)` of the input image
- `kernel`:     `(width, height)` of the kernel
- `channels`:   number of input channels
- `stride`:     `(stride_w, stride_h)`
- `padding`:    `(pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi)`
- `dilation`:   `(dil_w, dil_h)`
- `flipkernel`: `Bool`, whether kernel indices are mirrored when filling the column buffer
"""
struct ConvDims{img, kernel, channels, stride, padding, dilation, flipkernel} end

# Accessors: each one simply returns the corresponding type parameter.
img_size(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = I
kernel_size(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = K
img_channels(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = C
# NOTE(review): unless `Base.stride` is imported elsewhere in this file, this defines a
# module-local `stride` rather than extending Base's -- confirm that is intended.
stride(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = S
padding(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = P
dilation(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = D
flipkernel(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = F

"""
    output_size(c::ConvDims)

Calculate the output dimensions `(O_w, O_h)` of this convolution.

Along each axis the result is
`div(in + pad_lo + pad_hi - (kernel - 1) * dilation - 1, stride) + 1`,
using that axis' own kernel size, dilation and stride.
"""
function output_size(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F}
    O_w = div(I[1] + P[1] + P[2] - (K[1] - 1) * D[1] - 1, S[1]) + 1
    # The height must use the second (height) entry of kernel/dilation/stride; using the
    # width entries here gives wrong results for non-square configurations.
    O_h = div(I[2] + P[3] + P[4] - (K[2] - 1) * D[2] - 1, S[2]) + 1
    return (O_w, O_h)
end
30+
31+
function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, cdims::ConvDims) where T
32+
width, height = img_size(cdims)
33+
kernel_w, kernel_h = kernel_size(cdims)
34+
channels = img_channels(cdims)
35+
pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi = padding(cdims)
36+
dil_w, dil_h = dilation(cdims)
37+
stride_w, stride_h = stride(cdims)
38+
width_col, height_col = output_size(cdims)
39+
40+
if flipkernel(cdims)
41+
flipk = (w, h) -> (kernel_w - w + 1, kernel_h - h + 1)
42+
else
43+
flipk = (w, h) -> (w, h)
44+
end
2045

21-
#pragma omp parallel for
22-
for c = 1:channels_col
23-
w_offset = (c - 1) % kernel_w
24-
h_offset = div(c - 1, kernel_w) % kernel_h
25-
c_im = div(c - 1, kernel_h * kernel_w)
26-
if mode == 0
27-
w_offset = kernel_w - 1 - w_offset
28-
h_offset = kernel_h - 1 - h_offset
46+
# Reshape col for easy access.
47+
col_reshaped = reshape(col, (width_col, height_col, kernel_w, kernel_h, channels))
48+
49+
# Let us first calculate the number of rows/columns within which we must zero out some
50+
# portion of the image patches we're copying over. Note the subtractions on the `_hi`
51+
# variants are due to us needing to account for padding that is completely ignored due
52+
# to stride/dilation/kernel size combinations.
53+
spill_w_lo = ceil(Int, pad_w_lo/stride_w)
54+
spill_w_hi = width_col - div(width + pad_w_lo - (kernel_w - 1)*dil_w, stride_w)
55+
spill_h_lo = ceil(Int, pad_h_lo/stride_h)
56+
spill_h_hi = height_col - div(height + pad_h_lo - (kernel_h - 1)*dil_h, stride_h)
57+
spill_w_hi_abs = width_col - spill_w_hi + 1
58+
spill_h_hi_abs = height_col - spill_h_hi + 1
59+
60+
# First, a helper function to project from output (w, h) to input (input_w, input_h)
61+
project(idx, stride, pad) = (idx - 1)*stride - pad + 1
62+
63+
# These are the regions we're going to have to run with cognizance of padding
64+
padded_regions = (
65+
(1:width_col, 1:spill_h_lo),
66+
(1:spill_w_lo, (spill_h_lo+1):(spill_h_hi_abs-1)),
67+
(spill_w_hi_abs:width_col, (spill_h_lo+1):(spill_h_hi_abs-1)),
68+
(1:width_col, spill_h_hi_abs:height_col),
69+
)
70+
71+
# We begin by copying the central region of the image which requires no padding at all.
72+
# Eliminating the branches of the fully generalized version below gives us a nice
73+
# speedup on the majority of the data.
74+
for c in 1:channels
75+
for kh in 1:kernel_h
76+
for kw in 1:kernel_w
77+
for h in (spill_h_lo+1):(height_col - spill_h_hi)
78+
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
79+
80+
@inbounds for w in (spill_w_lo+1):(width_col - spill_w_hi)
81+
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
82+
col_reshaped[w, h, flipk(kw, kh)..., c] = img[input_kw, input_kh, c]
83+
end
84+
end
85+
end
2986
end
30-
for h = 1:height_col
31-
for w = 1:width_col
32-
h_pad = (h - 1) * stride_h - pad_h + h_offset * dil_h
33-
w_pad = (w - 1) * stride_w - pad_w + w_offset * dil_w
34-
if h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width
35-
col[((c - 1)*height_col+h-1) * width_col + w] =
36-
img[(c_im * height + h_pad) * width + w_pad + 1]
37-
else
38-
col[((c - 1)*height_col+h - 1) * width_col + w] = 0
87+
end
88+
89+
# For each "padded region", we run the fully general version
90+
for (w_region, h_region) in padded_regions
91+
for c in 1:channels
92+
for kh in 1:kernel_h
93+
for kw in 1:kernel_w
94+
@inbounds for h in h_region
95+
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
96+
97+
# If this column is off the edge, then deal with the entire thing
98+
# in one fell swoop, like a ravenous flock of crows. CAW CAW.
99+
if input_kh <= 0 || input_kh > height
100+
for w in w_region
101+
col_reshaped[w, h, flipk(kw, kh)..., c] = zero(eltype(col_reshaped))
102+
end
103+
continue
104+
end
105+
106+
@inbounds for w in w_region
107+
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
108+
109+
# If this pixel is off the edge of the map, clear it out.
110+
if input_kw <= 0 || input_kw > width
111+
col_reshaped[w, h, flipk(kw, kh)..., c] = zero(eltype(col_reshaped))
112+
continue
113+
end
114+
115+
# Copy the data over
116+
col_reshaped[w, h, flipk(kw, kh)..., c] = img[input_kw, input_kh, c]
117+
end
118+
end
39119
end
40120
end
41121
end
@@ -256,26 +336,41 @@ function depthwiseconv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4},
256336
return dx
257337
end
258338

339+
"""
    conv2d!(y, x, w, cdims::ConvDims; alpha=T(1))

Compute a 2-d convolution of `x` with `w` into `y`, using the geometry described by `cdims`.

For each image `n` along the fourth dimension of `x`, the image is unrolled into a
scratch buffer with `im2col_2d!` and then multiplied with `w` through `gemm!`, writing
straight into the `n`-th slice of `y`. Returns `y`.

`y` is treated as `(width, height, channels, batch)`, so the number of output channels
is `size(y, 3)`.
"""
function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4},
                 cdims::ConvDims; alpha=T(1)) where T
    Ww, Hw = kernel_size(cdims)
    Wy, Hy = output_size(cdims)
    Cx = img_channels(cdims)
    Cy = size(y, 3)

    # GEMM shapes: (M x K) * (K x N) -> (M x N). N is the number of output channels, and
    # Y is the number of elements in one output image, i.e. the stride between batch items.
    M, N, K, Y = Wy*Hy, Cy, Ww*Hw*Cx, Wy*Hy*Cy

    # Scratch buffer, reused for every image in the batch.
    x2 = similar(x, im2col_dims(w, y))
    for n in 1:size(x, 4)
        im2col_2d!(view(x, :, :, :, n), x2, cdims)
        gemm!('N','N',M,N,K,alpha,pointer(x2),pointer(w),T(0),pointer(y,(n - 1)*Y + 1))
    end
    return y
end
354+
259355
"""
    conv2d!(y, x, w; padding=0, stride=1, dilation=1, mode=0, alpha=T(1))

Keyword front-end for the `ConvDims`-based `conv2d!`.

Validates `mode` and the channel counts, expands `padding`, `stride` and `dilation` with
`psize`, packs everything into a `ConvDims` and forwards to the `cdims` method. The kernel
is flagged as flipped when `mode == 0`.

# Throws
- `ArgumentError` if `mode` is neither 0 nor 1.
- `DimensionMismatch` if the channel count of `x` differs from the third dimension of `w`.
"""
function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
                 padding=0, stride=1, dilation=1, mode=0, alpha=T(1)) where T
    mode in (0, 1) || throw(ArgumentError("conv2d only supports mode=0 or 1."))

    img_w, img_h, img_c = size(x, 1), size(x, 2), size(x, 3)
    ker_w, ker_h, ker_c = size(w, 1), size(w, 2), size(w, 3)

    # The number of channels in `x` must match the number of channels in each kernel
    # of `w`.
    img_c == ker_c || throw(DimensionMismatch())

    pad_w, pad_h = psize(padding, x)
    str_w, str_h = psize(stride, x)
    dil_w, dil_h = psize(dilation, x)

    # Padding is symmetric here, so the low and high amounts are equal on each axis.
    cdims = ConvDims{(img_w, img_h), (ker_w, ker_h), img_c, (str_w, str_h),
                     (pad_w, pad_w, pad_h, pad_h), (dil_w, dil_h), mode == 0}()
    return conv2d!(y, x, w, cdims; alpha=alpha)
end
280375

281376
function conv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, dy::AbstractArray{T,4};
@@ -332,37 +427,37 @@ function im2col2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,
332427
n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T
333428
Wx,Hx,Cx,Nx = size(x)
334429
Ww,Hw,C1,C2 = w
335-
xn = x[:, :, :, n]
336-
im2col_2d!(xn,x2,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1,1,mode)
430+
xn = view(x, :, :, :, n)
431+
cdims = ConvDims{(Wx,Hx),(Ww,Hw),Cx,(s1,s2),(p1,p1,p2,p2),(1,1), mode == 0}()
432+
im2col_2d!(xn,x2,cdims)
337433
return x2
338434
end
339435

340436
"""
    im2col2d!(w, x, x2, n, p1, p2, s1, s2, d1, d2, mode)

Unroll the `n`-th image of the 4-d array `x` into the column buffer `x2` and return `x2`.

The geometry is taken from `size(x)` and `size(w)`; `p1`/`p2`, `s1`/`s2` and `d1`/`d2` are
the per-axis padding, stride and dilation. The kernel is flagged as flipped when
`mode == 0`. The image slice is passed as a view, so no copy of `x` is made.
"""
function im2col2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
                   n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int,
                   mode::Int) where T
    img_w, img_h, img_c = size(x, 1), size(x, 2), size(x, 3)
    ker_w, ker_h = size(w, 1), size(w, 2)

    # Same padding on the low and high side of each axis.
    cdims = ConvDims{(img_w, img_h), (ker_w, ker_h), img_c, (s1, s2),
                     (p1, p1, p2, p2), (d1, d2), mode == 0}()
    im2col_2d!(@view(x[:, :, :, n]), x2, cdims)
    return x2
end
348445

349446
"""
    col2im2d!(w::NTuple{4,Int}, x, x2, n, p1, p2, s1, s2, mode)

Run `col2im_2d!` from the column buffer `x2` into the `n`-th image of `x`, with the kernel
size given as a size tuple `w` and a dilation of 1 on both axes. Returns `x`.

The target slice is a view of `x`, so `col2im_2d!` writes directly into `x`.
"""
function col2im2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
                   n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T
    img_w, img_h, img_c = size(x, 1), size(x, 2), size(x, 3)
    ker_w, ker_h = w[1], w[2]
    slice = @view x[:, :, :, n]
    col2im_2d!(x2, slice, img_w, img_h, img_c, ker_w, ker_h, p1, p2, s1, s2, 1, 1, mode)
    return x
end
358454

359455
"""
    col2im2d!(w, x, x2, n, p1, p2, s1, s2, d1, d2, mode)

Run `col2im_2d!` from the column buffer `x2` into the `n`-th image of `x`, taking the
kernel size from `size(w)` and the dilation from `d1`/`d2`. Returns `x`.

The target slice is a view of `x`, so `col2im_2d!` writes directly into `x`.
"""
function col2im2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
                   n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int,
                   mode::Int) where T
    img_w, img_h, img_c = size(x, 1), size(x, 2), size(x, 3)
    ker_w, ker_h = size(w, 1), size(w, 2)
    slice = @view x[:, :, :, n]
    col2im_2d!(x2, slice, img_w, img_h, img_c, ker_w, ker_h, p1, p2, s1, s2, d1, d2, mode)
    return x
end
368463

@@ -445,7 +540,7 @@ function im2col3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArr
445540
s3::Int, d1::Int, d2::Int, d3::Int, mode::Int) where T
446541
Wx,Hx,Dx,Cx,Nx = size(x)
447542
Ww,Hw,Dw,C1,C2 = size(w)
448-
xn = x[:, :, :, :, n]
543+
xn = view(x, :, :, :, :, n)
449544
im2col_3d!(xn,x2,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode)
450545
return x2
451546
end
@@ -455,8 +550,7 @@ function col2im3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArr
455550
s3::Int, d1::Int, d2::Int, d3::Int, mode::Int) where T
456551
Wx,Hx,Dx,Cx,Nx = size(x)
457552
Ww,Hw,Dw,C1,C2 = size(w)
458-
xn = x[:, :, :, :, n]
553+
xn = view(x, :, :, :, :, n)
459554
col2im_3d!(x2,xn,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode)
460-
x[:, :, :, :, n] = xn
461555
return x
462556
end

0 commit comments

Comments
 (0)