
Commit 1365788

Improve performance of tiny convolutions with the direct family of convolutions
Our approach is two-fold: use `calc_padding_regions()` to give us a fast path for the central part of a convolution, and eliminate allocations. We also move a little more information to compile time.
1 parent: 231532d

File tree

7 files changed: +244 −44 lines


src/dim_helpers.jl

Lines changed: 7 additions & 3 deletions
@@ -62,11 +62,15 @@ spatial dimension at the end of the spatial dimensions. This does so for a Conv
     )
 end
 
-@inline function insert_singleton_spatial_dimension(x::AbstractArray)
-    return reshape(x, size(x)[1:end-2]..., 1, size(x)[end-1:end]...)
+# We specialize common cases
+@inline function insert_singleton_spatial_dimension(x::AbstractArray{T,3}) where {T}
+    return reshape(x, size(x,1), 1, size(x,2), size(x,3))
+end
+@inline function insert_singleton_spatial_dimension(x::AbstractArray{T,4}) where {T}
+    return reshape(x, size(x,1), size(x,2), 1, size(x,3), size(x,4))
 end
 
-# Helper to do this multiple times
+# Helper to do this as many times as needed
 @inline function insert_singleton_spatial_dimension(x, reps::Int)
     for r in 1:reps
         x = insert_singleton_spatial_dimension(x)
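As a quick sanity check of what the new specializations compute (a sketch, not part of the commit; the `(W, C, N)` example layout is assumed):

    # A 3-d array (W, C, N) gains a singleton spatial dimension before C:
    x = rand(Float32, 8, 3, 2)                 # (W=8, C=3, N=2)
    y = insert_singleton_spatial_dimension(x)  # same data, size (8, 1, 3, 2)
    @assert size(y) == (8, 1, 3, 2)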

src/dim_helpers/DenseConvDims.jl

Lines changed: 8 additions & 8 deletions
@@ -62,16 +62,16 @@ end
 
 function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M}
     # First, check that channel counts are all correct:
-    @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))")
-    @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))")
-    @assert w[end-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end-1]) vs. $(channels_in(cdims)))")
-    @assert w[end] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[end]) vs. $(channels_out(cdims)))")
+    @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
+    @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
+    @assert w[M-1] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M-1]) vs. $(channels_in(cdims)))")
+    @assert w[M] == channels_out(cdims) DimensionMismatch("Kernel output channel count ($(w[M]) vs. $(channels_out(cdims)))")
 
     # Next, check that the spatial dimensions match up
-    @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))")
-    @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))")
-    @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))")
+    @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
+    @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
+    @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))")
 
     # Finally, check that the batch size matches
-    @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))")
+    @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))")
 end
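The change from `end` to the static parameter `M` is part of the commit's "move information to compile time" theme: with `x::NTuple{M}`, an index like `x[M-1]` is fixed for each method instantiation. A minimal sketch of the pattern (hypothetical toy function, not from the commit; how much this matters depends on what inference can already prove about `end`):

    # `M` is a static parameter known per tuple length, so `t[M-1]`
    # needs no run-time `lastindex` lookup.
    second_to_last(t::NTuple{M}) where {M} = t[M-1]
    @assert second_to_last((10, 20, 30, 40)) == 30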

src/dim_helpers/DepthwiseConvDims.jl

Lines changed: 20 additions & 26 deletions
@@ -7,19 +7,16 @@ Concrete subclass of `ConvDims` for a depthwise convolution. Differs primarily
 characterization by C_in, C_mult, rather than C_in, C_out. Useful to be separate from
 DenseConvDims primarily for channel calculation differences.
 """
-struct DepthwiseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F}
+struct DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F} <: ConvDims{N,S,P,D,F}
     I::NTuple{N, Int}
-    K::NTuple{N, Int}
-    C_in::Int
-    C_mult::Int
 end
 
 # Getters for the fields
 input_size(c::DepthwiseConvDims) = c.I
-kernel_size(c::DepthwiseConvDims) = c.K
-channels_in(c::DepthwiseConvDims) = c.C_in
-channels_out(c::DepthwiseConvDims) = c.C_in * channel_multiplier(c)
-channel_multiplier(c::DepthwiseConvDims) = c.C_mult
+kernel_size(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = K
+channels_in(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_in
+channels_out(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_in * C_mult
+channel_multiplier(c::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,F}) where {N,K,C_in,C_mult,S,P,D,F} = C_mult
 
 
 # Convenience wrapper to create DepthwiseConvDims objects
@@ -37,22 +34,19 @@ function DepthwiseConvDims(x_size::NTuple{M}, w_size::NTuple{M};
 
     return DepthwiseConvDims{
         M - 2,
+        # Kernel spatial size
+        w_size[1:end-2],
+        # Input channels
+        x_size[end-1],
+        # Channel multiplier
+        w_size[end-1],
         stride,
         padding,
         dilation,
         flipkernel
     }(
         # Image spatial size
         x_size[1:end-2],
-
-        # Kernel spatial size
-        w_size[1:end-2],
-
-        # Input channels
-        x_size[end-1],
-
-        # Channel multiplier
-        w_size[end-1],
     )
 end
 
@@ -69,22 +63,22 @@ end
 function DepthwiseConvDims(c::DepthwiseConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
                            C_in=channels_in(c), C_m=channel_multiplier(c), S=stride(c),
                            P=padding(c), D=dilation(c), F=flipkernel(c))
-    return DepthwiseConvDims{N, S, P, D, F}(I, K, C_in, C_m)
+    return DepthwiseConvDims{N, K, C_in, C_m, S, P, D, F}(I)
 end
 
 # This one is basically the same as for DenseConvDims, we only change a few lines for kernel channel count
 function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DepthwiseConvDims) where {M}
     # First, check that channel counts are all correct:
-    @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[end-1]) vs. $(channels_in(cdims)))")
-    @assert y[end-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[end-1]) vs. $(channels_out(cdims)))")
-    @assert w[end-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[end-1]) vs. $(channel_multiplier(cdims))")
-    @assert w[end] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[end]) vs. $(channels_in(cdims)))")
+    @assert x[M-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
+    @assert y[M-1] == channels_out(cdims) DimensionMismatch("Data output channel count ($(y[M-1]) vs. $(channels_out(cdims)))")
+    @assert w[M-1] == channel_multiplier(cdims) DimensionMismatch("Kernel multiplier channel count ($(w[M-1]) vs. $(channel_multiplier(cdims)))")
+    @assert w[M] == channels_in(cdims) DimensionMismatch("Kernel input channel count ($(w[M]) vs. $(channels_in(cdims)))")
 
     # Next, check that the spatial dimensions match up
-    @assert x[1:end-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:end-2]) vs. $(input_size(cdims)))")
-    @assert y[1:end-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:end-2]) vs. $(output_size(cdims)))")
-    @assert w[1:end-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:end-2]) vs. $(kernel_size(cdims)))")
+    @assert x[1:M-2] == input_size(cdims) DimensionMismatch("Data input spatial size ($(x[1:M-2]) vs. $(input_size(cdims)))")
+    @assert y[1:M-2] == output_size(cdims) DimensionMismatch("Data output spatial size ($(y[1:M-2]) vs. $(output_size(cdims)))")
+    @assert w[1:M-2] == kernel_size(cdims) DimensionMismatch("Kernel spatial size ($(w[1:M-2]) vs. $(kernel_size(cdims)))")
 
     # Finally, check that the batch size matches
-    @assert x[end] == y[end] DimensionMismatch("Batch size ($(x[end]) vs. $(y[end]))")
+    @assert x[M] == y[M] DimensionMismatch("Batch size ($(x[M]) vs. $(y[M]))")
 end
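With kernel size and channel counts lifted into type parameters, the getters fold to constants wherever the `DepthwiseConvDims` type is known. A usage sketch (hypothetical sizes, assuming the convenience constructor shown above with default stride/padding/dilation):

    # (W=28, H=28, C_in=3, N=4) input, (5, 5, C_mult=2, C_in=3) kernel
    cdims = DepthwiseConvDims((28, 28, 3, 4), (5, 5, 2, 3))
    @assert kernel_size(cdims) == (5, 5)
    @assert channels_out(cdims) == 6   # C_in * C_mult = 3 * 2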

src/impl/conv_direct.jl

Lines changed: 101 additions & 1 deletion
@@ -1,4 +1,5 @@
 ## This file contains direct Julia implementations of 2d and 3d convolutions
+using Base.Threads
 
 # Helper functions for restricting x/w overreach
 function clamp_lo(x, w)
@@ -44,7 +45,7 @@ wrapper methods are available.
 """
 conv_direct!
 
-function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
+function conv_direct_old!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
                       w::AbstractArray{wT,5}, cdims::DenseConvDims;
                       alpha::yT = yT(1), beta = false) where {yT, xT, wT}
     check_dims(size(x), size(w), size(y), cdims)
@@ -106,6 +107,105 @@ function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
     return y
 end
 
+function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
+                      w::AbstractArray{wT,5}, cdims::DenseConvDims;
+                      alpha::yT = yT(1), beta = false) where {yT, xT, wT}
+    check_dims(size(x), size(w), size(y), cdims)
+
+    width, height, depth = input_size(cdims)
+    kernel_w, kernel_h, kernel_d = kernel_size(cdims)
+    out_c = channels_out(cdims)
+    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
+    dil_w, dil_h, dil_d = dilation(cdims)
+    stride_w, stride_h, stride_d = stride(cdims)
+    out_width, out_height, out_depth = output_size(cdims)
+
+    # Create a method that, at compile-time, determines how we're going to index into `w`
+    kproj(k, M, cdims::ConvDims{N,S,P,D,true}) where {N, S, P, D} = k
+    kproj(k, M, cdims::ConvDims{N,S,P,D,false}) where {N, S, P, D} = M - k + 1
+
+    # A helper function to project from output (w, h) to input (input_w, input_h)
+    project(idx, stride, pad) = (idx - 1)*stride - pad + 1
+
+    # Use `calc_padding_regions` to determine where we do or don't need to worry about padding
+    padded_regions, central_region = calc_padding_regions(cdims)
+
+    # Start with the central region
+    w_region, h_region, d_region = central_region
+    @inbounds for batch in 1:size(x, 5),
+        c_out in 1:out_c,
+        d_idx in d_region,
+        h_idx in h_region,
+        w_idx in w_region
+
+        # Since we're in the central region, we don't need to worry about clamping
+        dotprod = yT(0)
+        for c_in in 1:channels_in(cdims),
+            kd in 1:kernel_d,
+            kh in 1:kernel_h,
+            kw in 1:kernel_w
+
+            # Hoist me, you coward.
+            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
+            x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
+            x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w
+
+            x_val = x[x_w, x_h, x_d, c_in, batch]
+            w_val = w[kproj(kw, kernel_w, cdims),
+                      kproj(kh, kernel_h, cdims),
+                      kproj(kd, kernel_d, cdims),
+                      c_in, c_out]
+            dotprod = muladd(x_val, w_val, dotprod)
+        end
+        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
+    end
+
+    # Next, do potentially-padded regions:
+    @inbounds for (w_region, h_region, d_region) in padded_regions,
+        batch in 1:size(x, 5),
+        c_out in 1:out_c,
+        d_idx in d_region,
+        h_idx in h_region,
+        w_idx in w_region
+
+        # Probe for out-of-bounds accesses on `x` and `continue` if we hit one
+        dotprod = yT(0)
+        for c_in in 1:channels_in(cdims),
+            kd in 1:kernel_d
+
+            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
+            if x_d <= 0 || x_d > depth
+                continue
+            end
+
+            for kh in 1:kernel_h
+                x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
+                if x_h <= 0 || x_h > height
+                    continue
+                end
+
+                for kw in 1:kernel_w
+                    x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w
+                    if x_w <= 0 || x_w > width
+                        continue
+                    end
+
+                    x_val = x[x_w, x_h, x_d, c_in, batch]
+                    w_val = w[kproj(kw, kernel_w, cdims),
+                              kproj(kh, kernel_h, cdims),
+                              kproj(kd, kernel_d, cdims),
+                              c_in, c_out]
+                    dotprod = muladd(x_val, w_val, dotprod)
+                end
+            end
+        end
+
+        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
+    end
+
+    return y
+end
+
 ## Gradient definitions
 """
     ∇conv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0)
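The speedup comes from `calc_padding_regions` splitting the output grid into one central block, where every kernel tap is guaranteed in-bounds so the inner loop needs no clamping or branching, plus a few thin border regions that keep the bounds probes. A one-dimensional sketch of that split, under the assumption that the central region is exactly the set of output indices whose full receptive field lands inside the input:

    # Hypothetical 1-D helper, not the package function.  Output index i reads
    # input positions (i-1)*stride - pad_lo + 1 .+ (0:k-1).*dil.
    function central_range_1d(input, k, pad_lo, stride, dil, out)
        first_ok = cld(pad_lo, stride) + 1                            # first tap >= 1
        last_ok  = fld(input - 1 + pad_lo - (k - 1)*dil, stride) + 1  # last tap <= input
        return max(first_ok, 1):min(last_ok, out)
    end
    # input=8, k=3, pad=1 each side, stride=1, dil=1, out=8: only the two
    # edge outputs ever touch padding, so 2:7 is the fast path.
    @assert central_range_1d(8, 3, 1, 1, 1, 8) == 2:7

The inner `kproj` methods serve the same goal as the region split: because the flipkernel flag `F` is a type parameter of `ConvDims`, the choice between `k` and `M - k + 1` is resolved at compile time rather than branching per tap.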

src/impl/conv_im2col.jl

Lines changed: 3 additions & 3 deletions
@@ -46,7 +46,7 @@ function conv_im2col!(
     N = channels_out(cdims)
     K = prod(kernel_size(cdims))*channels_in(cdims)
 
-    @inbounds for batch_idx in 1:size(x,5)
+    @threads for batch_idx in 1:size(x,5)
         # We invoke `@timeit_debug` on the outside of `im2col!()` because inference
         # doesn't like us putting it on the inside.
         im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
@@ -94,7 +94,7 @@ function ∇conv_filter_im2col!(
     N = channels_out(cdims)
     K = prod(output_size(cdims))
 
-    @inbounds for batch_idx in 1:size(x,5)
+    @threads for batch_idx in 1:size(x,5)
         im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
         GC.@preserve col, dw, dy, begin
             col_ptr = pointer(col)
@@ -142,7 +142,7 @@ function ∇conv_data_im2col!(
     N = prod(kernel_size(cdims))*channels_in(cdims)
     K = channels_out(cdims)
 
-    @inbounds for batch_idx in 1:size(dx, 5)
+    @threads for batch_idx in 1:size(dx, 5)
         GC.@preserve col, w, dy, begin
             dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
             w_ptr = pointer(w)
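Swapping `@inbounds` for `@threads` parallelizes each loop over the batch dimension across Julia's thread pool (sized by the `JULIA_NUM_THREADS` environment variable). A minimal sketch of the pattern, assuming each iteration only touches its own batch slice so there are no write races:

    using Base.Threads
    # Each thread processes its own batch indices; per-batch outputs are
    # independent, so no locking is needed.
    function scale_batches!(y, x)
        @threads for b in 1:size(x, 2)
            @views y[:, b] .= 2 .* x[:, b]
        end
        return y
    end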

src/impl/depthwiseconv_direct.jl

Lines changed: 103 additions & 1 deletion
@@ -18,7 +18,7 @@ channels in `x` is the last, not the second-to-last, as in a normal dense convol
 
 See the docstring for `conv_direct!()` for more on the optional parameters.
 """
-function depthwiseconv_direct!(
+function depthwiseconv_direct_old!(
     y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
     w::AbstractArray{wT,5}, cdims::DepthwiseConvDims;
     alpha::yT = yT(1), beta::yT = yT(0)) where {yT, xT, wT}
@@ -83,6 +83,108 @@ function depthwiseconv_direct!(
     return y
 end
 
+function depthwiseconv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
+                               w::AbstractArray{wT,5}, cdims::DepthwiseConvDims;
+                               alpha::yT = yT(1), beta = false) where {yT, xT, wT}
+    check_dims(size(x), size(w), size(y), cdims)
+
+    width, height, depth = input_size(cdims)
+    kernel_w, kernel_h, kernel_d = kernel_size(cdims)
+    out_c = channels_out(cdims)
+    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
+    dil_w, dil_h, dil_d = dilation(cdims)
+    stride_w, stride_h, stride_d = stride(cdims)
+    out_width, out_height, out_depth = output_size(cdims)
+
+    # Create a method that, at compile-time, determines how we're going to index into `w`
+    kproj(k, M, cdims::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,true}) where {N, K, C_in, C_mult, S, P, D} = k
+    kproj(k, M, cdims::DepthwiseConvDims{N,K,C_in,C_mult,S,P,D,false}) where {N, K, C_in, C_mult, S, P, D} = M - k + 1
+
+    # A helper function to project from output (w, h) to input (input_w, input_h)
+    project(idx, stride, pad) = (idx - 1)*stride - pad + 1
+
+    # Use `calc_padding_regions` to determine where we do or don't need to worry about padding
+    padded_regions, central_region = calc_padding_regions(cdims)
+
+    # Start with the central region
+    w_region, h_region, d_region = central_region
+    @inbounds for batch in 1:size(x)[end],
+        c_mult in 1:channel_multiplier(cdims),
+        c_in in 1:channels_in(cdims),
+        d_idx in d_region,
+        h_idx in h_region,
+        w_idx in w_region
+
+        # Since we're in the central region, we don't need to worry about clamping
+        dotprod = yT(0)
+        c_out = (c_in - 1)*channel_multiplier(cdims) + c_mult
+        for kd in 1:kernel_d,
+            kh in 1:kernel_h,
+            kw in 1:kernel_w
+
+            # Hoist me, you coward.
+            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
+            x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
+            x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w
+
+            x_val = x[x_w, x_h, x_d, c_in, batch]
+            w_val = w[kproj(kw, kernel_w, cdims),
+                      kproj(kh, kernel_h, cdims),
+                      kproj(kd, kernel_d, cdims),
+                      c_mult, c_in]
+            dotprod = muladd(x_val, w_val, dotprod)
+        end
+        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
+    end
+
+    # Next, do potentially-padded regions:
+    @inbounds for (w_region, h_region, d_region) in padded_regions,
+        batch in 1:size(x)[end],
+        c_mult in 1:channel_multiplier(cdims),
+        c_in in 1:channels_in(cdims),
+        d_idx in d_region,
+        h_idx in h_region,
+        w_idx in w_region
+
+        # Probe for out-of-bounds accesses on `x` and `continue` if we hit one
+        dotprod = yT(0)
+        c_out = (c_in - 1)*channel_multiplier(cdims) + c_mult
+        for kd in 1:kernel_d
+            x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d
+            if x_d <= 0 || x_d > depth
+                continue
+            end
+
+            for kh in 1:kernel_h
+                x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h
+                if x_h <= 0 || x_h > height
+                    continue
+                end
+
+                for kw in 1:kernel_w
+                    x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w
+                    if x_w <= 0 || x_w > width
+                        continue
+                    end
+
+                    x_val = x[x_w, x_h, x_d, c_in, batch]
+                    w_val = w[kproj(kw, kernel_w, cdims),
+                              kproj(kh, kernel_h, cdims),
+                              kproj(kd, kernel_d, cdims),
+                              c_mult, c_in]
+                    dotprod = muladd(x_val, w_val, dotprod)
+                end
+            end
+        end
+
+        y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch]
+    end
+
+    return y
+end
+
 """
     ∇depthwiseconv_data_direct!(dx, dy, w, cdims; alpha=1, beta=0)
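The depthwise channel bookkeeping is the one real difference from the dense kernel: each input channel c_in produces channel_multiplier outputs, grouped consecutively via c_out = (c_in - 1)*C_mult + c_mult. A small worked check of that mapping (a sketch, not part of the commit):

    # With C_in = 3 and C_mult = 2, the (c_in, c_mult) pairs cover 1:6 exactly once:
    # (1,1)->1, (1,2)->2, (2,1)->3, (2,2)->4, (3,1)->5, (3,2)->6
    C_in, C_mult = 3, 2
    c_outs = [(c_in - 1)*C_mult + c_mult for c_in in 1:C_in, c_mult in 1:C_mult]
    @assert sort(vec(c_outs)) == collect(1:6)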
