Skip to content

Commit 136aa82

Browse files
authored
conv_direct!(): The performance fix (#142)
conv_direct!(): The performance fix
2 parents e4fc929 + 3847a8c commit 136aa82

File tree

8 files changed

+203
-130
lines changed

8 files changed

+203
-130
lines changed

src/dim_helpers.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,15 @@ spatial dimension at the end of the spatial dimensions. This does so for a Conv
6262
)
6363
end
6464

65-
@inline function insert_singleton_spatial_dimension(x::AbstractArray)
66-
return reshape(x, size(x)[1:end-2]..., 1, size(x)[end-1:end]...)
65+
# We specialize common cases
66+
@inline function insert_singleton_spatial_dimension(x::AbstractArray{T,3}) where {T}
67+
return reshape(x, size(x,1), 1, size(x,2), size(x,3))
68+
end
69+
@inline function insert_singleton_spatial_dimension(x::AbstractArray{T,4}) where {T}
70+
return reshape(x, size(x,1), size(x,2), 1, size(x,3), size(x,4))
6771
end
6872

69-
# Helper to do this multiple times
73+
# Helper to do this as many times as needed
7074
@inline function insert_singleton_spatial_dimension(x, reps::Int)
7175
for r in 1:reps
7276
x = insert_singleton_spatial_dimension(x)

src/dim_helpers/DenseConvDims.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,16 @@ end
6262

6363
function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M}
6464
# First, check that channel counts are all correct:
65-
@assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))")
66-
@assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))")
67-
@assert w[end-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end-1]) vs. $(channels_in(cdims)))")
68-
@assert w[end] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[end]) vs. $(channels_out(cdims)))")
65+
@assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
66+
@assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
67+
@assert w[M-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)))")
68+
@assert w[M] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))")
6969

7070
# Next, check that the spatial dimensions match up
71-
@assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))")
72-
@assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))")
73-
@assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))")
71+
@assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
72+
@assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
73+
@assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))")
7474

7575
# Finally, check that the batch size matches
76-
@assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))")
76+
@assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))")
7777
end

src/dim_helpers/DepthwiseConvDims.jl

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,16 @@ Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily
77
characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from
88
DenseConvDims primarily for channel calculation differences.
99
"""
10-
struct DepthwiseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F}
10+
struct DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F} <: ConvDims{N,S,P,D,F}
1111
I::NTuple{N, Int}
12-
K::NTuple{N, Int}
13-
C_in::Int
14-
C_mult::Int
1512
end
1613

1714
# Getters for the fields
1815
input_size(c::DepthwiseConvDims) = c.I
19-
kernel_size(c::DepthwiseConvDims) = c.K
20-
channels_in(c::DepthwiseConvDims) = c.C_in
21-
channels_out(c::DepthwiseConvDims) = c.C_in * channel_multiplier(c)
22-
channel_multiplier(c::DepthwiseConvDims) = c.C_mult
16+
kernel_size(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = K
17+
channels_in(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_in
18+
channels_out(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_in * C_mult
19+
channel_multiplier(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_mult
2320

2421

2522
# Convenience wrapper to create DepthwiseConvDims objects
@@ -37,22 +34,19 @@ function DepthwiseConvDims(x_size::NTuple{M}, w_size::NTuple{M};
3734

3835
return DepthwiseConvDims{
3936
M - 2,
37+
# Kernel spatial size
38+
w_size[1:end-2],
39+
# Input channels
40+
x_size[end-1],
41+
# Channel multiplier
42+
w_size[end-1],
4043
stride,
4144
padding,
4245
dilation,
4346
flipkernel
4447
}(
4548
# Image spatial size
4649
x_size[1:end-2],
47-
48-
# Kernel spatial size
49-
w_size[1:end-2],
50-
51-
# Input channels
52-
x_size[end-1],
53-
54-
# Channel multiplier
55-
w_size[end-1],
5650
)
5751
end
5852

@@ -69,22 +63,22 @@ end
6963
function DepthwiseConvDims(c::DepthwiseConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
7064
C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c),
7165
P=padding(c), D=dilation(c), F=flipkernel(c))
72-
return DepthwiseConvDims{N, S, P, D, F}(I, K, C_in, C_m)
66+
return DepthwiseConvDims{N, K, C_in, C_m, S, P, D, F}(I)
7367
end
7468

7569
# This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count
7670
function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M}
7771
# First, check that channel counts are all correct:
78-
@assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))")
79-
@assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))")
80-
@assert w[end-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[end-1]) vs. $(channel_multiplier(cdims))")
81-
@assert w[end] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end]) vs. $(channels_in(cdims)))")
72+
@assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
73+
@assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
74+
@assert w[M-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[M-1]) vs. $(channel_multiplier(cdims))")
75+
@assert w[M] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M]) vs. $(channels_in(cdims)))")
8276

8377
# Next, check that the spatial dimensions match up
84-
@assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))")
85-
@assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))")
86-
@assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))")
78+
@assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
79+
@assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
80+
@assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))")
8781

8882
# Finally, check that the batch size matches
89-
@assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))")
83+
@assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))")
9084
end

src/impl/conv_direct.jl

Lines changed: 79 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
## This file contains direct Julia implementations of 2d and 3d convolutions
2+
using Base.Threads
23

34
# Helper functions for restricting x/w overreach
45
function clamp_lo(x, w)
@@ -57,50 +58,87 @@ function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
5758
stride_w, stride_h, stride_d = stride(cdims)
5859
out_width, out_height, out_depth = output_size(cdims)
5960

60-
# If we're doing crosscorr instead of conv, then don't bother to flip `w`
61-
if !flipkernel(cdims)
62-
w = w[end:-1:1, end:-1:1, end:-1:1, :, :]
63-
end
64-
61+
# Create a method that, at compile-time, determines how we're going to index into `w`
62+
kproj(k, M, cdims::ConvDims{N,S,P,D,true}) where {N, S, P, D} = k
63+
kproj(k, M, cdims::ConvDims{N,S,P,D,false}) where {N, S, P, D} = M - k + 1
64+
6565
# A helper function to project from output (w, h) to input (input_w, input_h)
66-
@inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1
66+
project(idx, stride, pad) = (idx - 1)*stride - pad + 1
6767

68-
# explicit formulation of convolution. Oh hoisting gods, hear my plea.
69-
@inbounds for batch in 1:size(x)[end],
68+
# Use `calc_padding_regions` to determine where we do or don't need to worry about padding
69+
padded_regions, central_region = calc_padding_regions(cdims)
70+
71+
# Start with the central region
72+
w_region, h_region, d_region = central_region
73+
@inbounds for batch in 1:size(x, 5),
74+
c_out in 1:out_c,
75+
d_idx in d_region,
76+
h_idx in h_region,
77+
w_idx in w_region
78+
79+
# Since we're in the central region, we don't need to worry about clamping
80+
dotprod = yT(0)
81+
for c_in in 1:channels_in(cdims),
82+
kd in 1:kernel_d,
83+
kh in 1:kernel_h,
84+
kw in 1:kernel_w
85+
86+
# Hoist me, you coward.
87+
x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
88+
x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
89+
x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w
90+
91+
x_val = x[x_w, x_h, x_d, c_in, batch]
92+
w_val = w[kproj(kw, kernel_w, cdims),
93+
kproj(kh, kernel_h, cdims),
94+
kproj(kd, kernel_d, cdims),
95+
c_in, c_out]
96+
dotprod = muladd(x_val, w_val, dotprod)
97+
end
98+
y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
99+
end
100+
101+
# Next, do potentially-padded regions:
102+
@inbounds for (w_region, h_region, d_region) in padded_regions,
103+
batch in 1:size(x, 5),
70104
c_out in 1:out_c,
71-
d_idx in 1:out_depth,
72-
h_idx in 1:out_height,
73-
w_idx in 1:out_width
74-
75-
# Starting points of the window of x we're going to grab
76-
x_w = project(w_idx, stride_w, pad_w_lo)
77-
x_h = project(h_idx, stride_h, pad_h_lo)
78-
x_d = project(d_idx, stride_d, pad_d_lo)
79-
80-
# Grow that starting point into ranges
81-
x_widxs = x_w .+ (0:dil_w:(dil_w*kernel_w-1))
82-
x_hidxs = x_h .+ (0:dil_h:(dil_h*kernel_h-1))
83-
x_didxs = x_d .+ (0:dil_d:(dil_d*kernel_d-1))
84-
w_widxs = 1:kernel_w
85-
w_hidxs = 1:kernel_h
86-
w_didxs = 1:kernel_d
87-
88-
# Clamp the ranges to simulate padding
89-
x_widxs, w_widxs = clamp_lo(x_widxs, w_widxs)
90-
x_widxs, w_widxs = clamp_hi(x_widxs, w_widxs, width)
91-
x_hidxs, w_hidxs = clamp_lo(x_hidxs, w_hidxs)
92-
x_hidxs, w_hidxs = clamp_hi(x_hidxs, w_hidxs, height)
93-
x_didxs, w_didxs = clamp_lo(x_didxs, w_didxs)
94-
x_didxs, w_didxs = clamp_hi(x_didxs, w_didxs, depth)
95-
96-
# Grab our slices
97-
x_slice = view(x, x_widxs, x_hidxs, x_didxs, :, batch)
98-
w_slice = view(w, w_widxs, w_hidxs, w_didxs, :, c_out)
99-
100-
# Do the dotproduct dance, then weight by alpha/beta and git 'er done
101-
dotprod = sum(x_slice .* w_slice)
102-
y[w_idx, h_idx, d_idx, c_out, batch] = alpha*convert(yT, dotprod) +
103-
beta*y[w_idx, h_idx, d_idx, c_out, batch]
105+
d_idx in d_region,
106+
h_idx in h_region,
107+
w_idx in w_region
108+
109+
# Probe for out-of-bounds accesses on `x` and `continue` if we hit one
110+
dotprod = yT(0)
111+
for c_in in 1:channels_in(cdims),
112+
kd in 1:kernel_d
113+
114+
x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
115+
if x_d <= 0 || x_d > depth
116+
continue
117+
end
118+
119+
for kh in 1:kernel_h
120+
x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
121+
if x_h <= 0 || x_h > height
122+
continue
123+
end
124+
125+
for kw in 1:kernel_w
126+
x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w
127+
if x_w <= 0 || x_w > width
128+
continue
129+
end
130+
131+
x_val = x[x_w, x_h, x_d, c_in, batch]
132+
w_val = w[kproj(kw, kernel_w, cdims),
133+
kproj(kh, kernel_h, cdims),
134+
kproj(kd, kernel_d, cdims),
135+
c_in, c_out]
136+
dotprod = muladd(x_val, w_val, dotprod)
137+
end
138+
end
139+
end
140+
141+
y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
104142
end
105143

106144
return y

src/impl/conv_im2col.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function conv_im2col!(
4646
N = channels_out(cdims)
4747
K = prod(kernel_size(cdims))*channels_in(cdims)
4848

49-
@inbounds for batch_idx in 1:size(x,5)
49+
@threads for batch_idx in 1:size(x,5)
5050
# We invoke `@timeit_debug` on the outside of `im2col!()` because inference
5151
# doesn't like us putting it on the inside.
5252
im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
@@ -94,7 +94,7 @@ function ∇conv_filter_im2col!(
9494
N = channels_out(cdims)
9595
K = prod(output_size(cdims))
9696

97-
@inbounds for batch_idx in 1:size(x,5)
97+
@threads for batch_idx in 1:size(x,5)
9898
im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
9999
GC.@preserve col, dw, dy, begin
100100
col_ptr = pointer(col)
@@ -142,7 +142,7 @@ function ∇conv_data_im2col!(
142142
N = prod(kernel_size(cdims))*channels_in(cdims)
143143
K = channels_out(cdims)
144144

145-
@inbounds for batch_idx in 1:size(dx, 5)
145+
@threads for batch_idx in 1:size(dx, 5)
146146
GC.@preserve col, w, dy, begin
147147
dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
148148
w_ptr = pointer(w)

0 commit comments

Comments
 (0)