Commit 8d3e059

Author: Anton Smirnov

Fix convolution & pooling type-stability (#370)

1 parent 3906ec1 · commit 8d3e059

13 files changed: +435 −342 lines

src/conv.jl

Lines changed: 22 additions & 25 deletions
```diff
@@ -26,20 +26,20 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!
 # cdims = ConvDims(x, w; stride=2, dilation=(3,2))
 # dx = ∇conv_data(conv(x, w, cdims), w, cdims)
 
-# The computational flow, starting from the user facing functions,
-# goes through the following steps:
+# The computational flow, starting from the user facing functions,
+# goes through the following steps:
 #
-# STEP 1:
+# STEP 1:
 #   use ConvDims objects (only for `conv` and `depthwiseconv`)
-# STEP 2:
+# STEP 2:
 #   define autoallocating version (frontend and implementations)
-# STEP 3:
+# STEP 3:
 #   reshape to 3d convolutions (frontend and implementions)
-# STEP 4:
+# STEP 4:
 #   choose implementation
 
 # TODO: should we also add
-# STEP X:
+# STEP X:
 #   use homogeneus datatypes
 #   to handle etherogeneus inputs now handled by conv_direct?
 
```

```diff
@@ -48,22 +48,23 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!
 """
     conv(x, w; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1)
 
-Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors
-in 1d/2d/3d convolutions respectively.
+Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors
+in 1d/2d/3d convolutions respectively.
 """
-function conv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false, groups = 1) where {T, N}
-    stride = expand(Val(N-2), stride)
-    pad = expand(Val(N-2), pad)
-    dilation = expand(Val(N-2), dilation)
-    cdims = DenseConvDims(x, w; stride=stride, padding=pad, dilation=dilation, flipkernel=flipped, groups = groups)
+function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N}
+    stride = expand(Val(N - 2), stride)
+    padding = expand(Val(N - 2), pad)
+    dilation = expand(Val(N - 2), dilation)
+    cdims = DenseConvDims(
+        size(x), size(w); stride, padding, dilation, flipkernel=flipped, groups)
     return conv(x, w, cdims)
 end
 
 """
     depthwiseconv(x, w; stride=1, pad=0, dilation=1, flipped=false)
 
-Depthwise convolution operation with filter `w` on input `x`. `x` and `w`
-are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.
+Depthwise convolution operation with filter `w` on input `x`. `x` and `w`
+are 3d/4d/5d tensors in 1d/2d/3d convolutions respectively.
 """
 function depthwiseconv(x, w::AbstractArray{T, N}; stride=1, pad=0, dilation=1, flipped=false) where {T, N}
     stride = expand(Val(N-2), stride)
```
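For reference, a minimal usage sketch of the updated frontend; the array shapes here are illustrative assumptions, not part of the commit:

```julia
using NNlib

# A 2d convolution: 4d tensors laid out as (width, height, channels, batch).
x = rand(Float32, 28, 28, 3, 1)   # assumed input
w = rand(Float32, 5, 5, 3, 8)     # assumed 5x5 kernel mapping 3 -> 8 channels

# The frontend now builds DenseConvDims from size tuples, so the dims object
# is constructed from plain NTuples rather than from the arrays themselves.
y = conv(x, w; stride = 1, pad = 2, dilation = 1)
size(y)   # (28, 28, 8, 1): pad = 2 keeps 28x28 with a 5x5 kernel
```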
```diff
@@ -98,9 +99,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
     function $(Symbol("$(name)$(backend)"))(
             dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
             cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims}
-        dx = similar(dy, input_size(cdims)..., channels_in(cdims),
-            size(dy, N))
-
+        dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, N))
         return $(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...)
     end
 end
```
```diff
@@ -114,7 +113,6 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
             cdims::ConvDims; kwargs...) where {xT, yT, N}
         dw = similar(dy, kernel_size(cdims)..., channels_in(cdims) ÷ groupcount(cdims),
             channels_out(cdims))
-
         return $(Symbol("∇conv_filter$(backend)!"))(dw, x, dy, cdims; kwargs...)
     end
 end
```
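These autoallocating wrappers back the gradient entry points shown in the comment at the top of the file. A sketch with assumed shapes:

```julia
using NNlib

x = rand(Float32, 16, 16, 3, 1)           # assumed input
w = rand(Float32, 3, 3, 3, 4)             # assumed kernel
cdims = DenseConvDims(size(x), size(w))   # dims built from size tuples

y  = conv(x, w, cdims)                    # forward pass
dx = ∇conv_data(y, w, cdims)              # input gradient, output allocated by the wrapper
dw = ∇conv_filter(x, y, cdims)            # filter gradient, likewise autoallocated
size(dx) == size(x) && size(dw) == size(w)   # true
```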
```diff
@@ -197,15 +195,15 @@ for (front_name, backend) in (
                 G = 1,
                 C_in = channels_in(cdims) ÷ groupcount(cdims),
                 C_out = channels_out(cdims) ÷ groupcount(cdims))
-
+
             Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
                 x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
                 w = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
                 y = @view out[ntuple(i -> i == 4 ? wc : Colon(), 5)...]
                 Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...)
             end
 
-            return out
+            return out
         end
     end
 end
```
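The grouped path splits the input and kernel channel ranges with `Iterators.partition` and convolves each pair on its own task. The same partitioning in isolation, with assumed channel counts:

```julia
# 8 input channels and 16 output channels split into 4 groups: each input
# range is paired with the kernel range that consumes it.
groups = 4
x_cs = Iterators.partition(1:8,  8 ÷ groups)
w_cs = Iterators.partition(1:16, 16 ÷ groups)

collect(zip(x_cs, w_cs))
# [(1:2, 1:4), (3:4, 5:8), (5:6, 9:12), (7:8, 13:16)]
```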
```diff
@@ -232,12 +230,11 @@ function ∇conv_data!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
         Threads.@spawn ∇conv_data_im2col!(dxv, dyv, wv, cdims2; kwargs...)
     end
 
-    return out
+    return out
 end
 
 function ∇conv_filter!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
         in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: G, C <: ConvDims}
-
     dw_cs = Iterators.partition(1:size(out, 5),
         channels_out(cdims) ÷ groupcount(cdims))
     dy_cs = Iterators.partition(1:size(in2, 4),
```
```diff
@@ -256,7 +253,7 @@ function ∇conv_filter!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
         Threads.@spawn ∇conv_filter_im2col!(dw, x, dy, cdims2; kwargs...)
     end
 
-    return out
+    return out
 end
 
 
```
src/conv_bias_act.jl

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,13 +1,13 @@
 export conv_bias_act, conv_bias_act!
 
-function conv_bias_act(x::AbstractArray{xT,N}, w::AbstractArray{wT,N},
+function conv_bias_act(x::AbstractArray{xT,N}, w::AbstractArray{wT,N},
         cdims::ConvDims, b::AbstractArray{bT,N}, σ=identity; kwargs...) where {xT, wT, bT, N}
     y = similar(x, promote_type(xT, wT, bT), output_size(cdims)..., channels_out(cdims), size(x,N))
     conv_bias_act!(y, x, w, cdims, b, σ; kwargs...)
     return y
 end
 
-function conv_bias_act!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, w::AbstractArray{wT,5},
+function conv_bias_act!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5}, w::AbstractArray{wT,5},
         cdims::ConvDims, b::AbstractArray{bT,5}, σ=identity; kwargs...) where {yT, xT, wT, bT}
     conv!(y, x, w, cdims)
     y .= σ.(y .+ b)
```
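`conv_bias_act` fuses the convolution with the bias add and activation, as the body above shows (`conv!` followed by one fused broadcast). A usage sketch; the shapes and the choice of `relu` are assumptions for illustration:

```julia
using NNlib

x = rand(Float32, 10, 10, 3, 1)           # assumed input
w = rand(Float32, 3, 3, 3, 4)             # assumed kernel
b = zeros(Float32, 1, 1, 4, 1)            # bias, broadcastable over the output
cdims = DenseConvDims(size(x), size(w))

y = conv_bias_act(x, w, cdims, b, relu)   # conv!, then y .= relu.(y .+ b)
size(y)   # (8, 8, 4, 1)
```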

src/dim_helpers.jl

Lines changed: 1 addition & 2 deletions
```diff
@@ -45,14 +45,13 @@ function transpose_pad(cdims::ConvDims)
 end
 
 """
-    insert_singleton_spatial_dimension(cdims::DenseConvDims)
+    insert_singleton_spatial_dimension(cdims::ConvDims)
 
 When converting a 1d convolution to a 2d, or a 2d to a 3d, we need to insert a singleton
 spatial dimension at the end of the spatial dimensions. This does so for a ConvDims.
 """
 @inline function insert_singleton_spatial_dimension(cdims::C) where {C <: ConvDims}
     return basetype(C)(cdims;
-        N=spatial_dims(cdims) + 1,
         I=(input_size(cdims)..., 1),
         K=(kernel_size(cdims)..., 1),
         S=(stride(cdims)..., 1),
```
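The helper appends a trailing singleton to every size tuple so a lower-dimensional convolution can reuse the higher-dimensional code path. A sketch using the (internal, hence module-qualified) helper with assumed 1d sizes:

```julia
using NNlib

# Dims for a 1d convolution: 3d sizes laid out as (width, channels, batch).
cdims1 = DenseConvDims((28, 3, 1), (5, 3, 8))

# Promote to 2d by inserting a singleton spatial dimension at the end.
cdims2 = NNlib.insert_singleton_spatial_dimension(cdims1)
NNlib.kernel_size(cdims2)   # (5, 1)
```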

src/dim_helpers/ConvDims.jl

Lines changed: 55 additions & 59 deletions
```diff
@@ -6,13 +6,19 @@ export ConvDims
 Type system-level information about convolution dimensions. Critical for things like
 `im2col!()` to generate efficient code, and helpful to reduce the number of kwargs
 getting passed around.
-
-We don't want to specialize on things like image size/channel count, so we generally
-store those as fields, just for convenience, and to allow for non-breaking changes when
-we decide we _do_ want to specialize on those values. We always want to specialize on
-things like stride, padding, dilation, and kernel flipping though.
 """
-abstract type ConvDims{N, S, P, D, F} end
+abstract type ConvDims{N} end
+
+@inline spatial_dims(::ConvDims{N}) where N = N
+@inline groupcount(c::ConvDims) = 1
+
+# Below functions should be implemented by dims that subtype `ConvDims`.
+function input_size end
+function kernel_size end
+function stride end
+function padding end
+function dilation end
+function flipkernel end
 
 # Hack to get rid of type parameters
 function basetype(::Type{C}) where {C <: ConvDims}
```
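This is the heart of the type-stability fix: stride, padding, dilation, and kernel flipping move out of `ConvDims`' type parameters and behind a small accessor interface that concrete subtypes implement. A hypothetical minimal subtype showing the contract (`ToyDims` is invented for illustration and is not part of the package):

```julia
using NNlib

# Hypothetical subtype: only the spatial rank N (plus the padding tuple
# length M == 2N) stays in the type; everything else is a plain field
# exposed through the accessor interface declared above.
struct ToyDims{N,M} <: NNlib.ConvDims{N}
    I::NTuple{N,Int}   # spatial input size
    K::NTuple{N,Int}   # spatial kernel size
    S::NTuple{N,Int}   # stride
    P::NTuple{M,Int}   # lo/hi padding per spatial dimension
    D::NTuple{N,Int}   # dilation
    F::Bool            # flipkernel
end

NNlib.input_size(c::ToyDims)   = c.I
NNlib.kernel_size(c::ToyDims)  = c.K
NNlib.stride(c::ToyDims)       = c.S
NNlib.padding(c::ToyDims)      = c.P
NNlib.dilation(c::ToyDims)     = c.D
NNlib.flipkernel(c::ToyDims)   = c.F
NNlib.channels_in(c::ToyDims)  = 1   # stubs so the generic show() works
NNlib.channels_out(c::ToyDims) = 1

c = ToyDims((28, 28), (5, 5), (1, 1), (2, 2, 2, 2), (1, 1), false)
NNlib.output_size(c)   # (28, 28), via the generic definition in the next hunk
```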
```diff
@@ -27,13 +33,29 @@ function basetype(::Type{C}) where {C <: ConvDims}
     end
 end
 
-# Obvious getter definitions for the type system-level definitions
-spatial_dims(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = N
-stride(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = S
-padding(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = P
-dilation(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = D
-flipkernel(c::ConvDims{N,S,P,D,F}) where {N, S, P, D, F} = F
-groupcount(c::ConvDims) = 1
+function output_size(c::ConvDims)
+    I = input_size(c)
+    K = kernel_size(c)
+    S = stride(c)
+    P = padding(c)
+    D = dilation(c)
+
+    return ntuple(spatial_dims(c)) do i
+        return div(I[i] + P[(i-1)*2 + 1] + P[(i-1)*2 + 2] - (K[i] - 1) * D[i] - 1, S[i]) + 1
+    end
+end
+
+function Base.show(io::IO, cdims::C) where {C <: ConvDims}
+    I = (input_size(cdims)..., channels_in(cdims))
+    O = (output_size(cdims)..., channels_out(cdims))
+    K = kernel_size(cdims)
+    S = stride(cdims)
+    P = padding(cdims)
+    D = dilation(cdims)
+    F = flipkernel(cdims)
+    G = groupcount(cdims)
+    print(io, "$(basetype(C)): $I * $K -> $O, stride: $S, pad: $P, dil: $D, flip: $F, groups: $G")
+end
 
 """
     im2col_dims(c::ConvDims)
```
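The relocated `output_size` applies standard convolution arithmetic per spatial dimension: `div(I + Pl + Ph - (K - 1) * D - 1, S) + 1`. A worked check with assumed numbers:

```julia
# One spatial dimension: input 28, padding 2 lo / 2 hi, kernel 5,
# dilation 1, stride 2.
I, Pl, Ph, K, D, S = 28, 2, 2, 5, 1, 2
div(I + Pl + Ph - (K - 1) * D - 1, S) + 1   # 14
```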
```diff
@@ -81,57 +103,31 @@ function check_spdf(x_size::NTuple{N}, w_size::NTuple{N}, stride, padding, dilat
 
     # padding is kind of a special case; we allow it to be either 2-length or 4-length,
     # since we support asymmetrical padding
-    if length(ppadding) != 2*nd
-        if length(ppadding) == nd
-            # Do this repeat dance so that we get lo/hi symmetrical padding
-            ppadding = tuple(repeat(collect(ppadding), inner=2)...)
-        else
-            throw(DimensionMismatch("Padding $(length(ppadding))d, should be either $(nd)d or $(2*nd)d!"))
-        end
+    if length(ppadding) == 2 * nd
+        _validate_padding(x_size, w_size, ppadding, pdilation)
+        return pstride, ppadding, pdilation
     end
 
-    # Assert that kernel size * dilation is <= padded input size
-    for idx in 1:nd
+    length(ppadding) != nd && throw(DimensionMismatch(
+        "Padding $(length(ppadding))d, should be either $(nd)d or $(2*nd)d!"))
+
+    # Do this repeat dance so that we get lo/hi symmetrical padding
+    ppadding_expanded = ntuple(i -> ppadding[(i - 1) ÷ 2 + 1], 2 * nd)
+    _validate_padding(x_size, w_size, ppadding_expanded, pdilation)
+    return pstride, ppadding_expanded, pdilation
+end
+
+# Assert that kernel size * dilation is <= padded input size
+function _validate_padding(x_size::NTuple{N}, w_size::NTuple{N}, padding, dilation) where N
+    for idx in 1:(N - 2)
         Is = x_size[idx]
-        Pl = ppadding[(idx - 1)*2 + 1]
-        Ph = ppadding[(idx - 1)*2 + 2]
         Ks = w_size[idx]
-        Ds = pdilation[idx]
-        if Is + Pl + Ph < (Ks - 1)*Ds + 1
+        Pl = padding[(idx - 1) * 2 + 1]
+        Ph = padding[(idx - 1) * 2 + 2]
+        Ds = dilation[idx]
+        if Is + Pl + Ph < (Ks - 1) * Ds + 1
             throw(DimensionMismatch("Kernel * dilation (($Ks - 1) * $Ds + 1) cannot be larger than input + padding ($Is + $Pl + $Ph)!"))
         end
     end
-
-    return pstride, ppadding, pdilation
-end
-
-
-"""
-    output_size(c::ConvDims)
-
-Calculate the output (spatial) dimensions of the convolution. Get channel count via
-`channels_out(c)`, and batch count is unknowable.
-"""
-function output_size(c::ConvDims)
-    I = input_size(c)
-    K = kernel_size(c)
-    S = stride(c)
-    P = padding(c)
-    D = dilation(c)
-
-    return ntuple(spatial_dims(c)) do i
-        return div(I[i] + P[(i-1)*2 + 1] + P[(i-1)*2 + 2] - (K[i] - 1) * D[i] - 1, S[i]) + 1
-    end
-end
-
-# Override show() for these beauties
-function Base.show(io::IO, cdims::C) where {C <: ConvDims}
-    I = (input_size(cdims)..., channels_in(cdims))
-    O = (output_size(cdims)..., channels_out(cdims))
-    K = kernel_size(cdims)
-    S = stride(cdims)
-    P = padding(cdims)
-    D = dilation(cdims)
-    F = flipkernel(cdims)
-    G = groupcount(cdims)
-    print(io, "$(basetype(C)): $I * $K -> $O, stride: $S, pad: $P, dil: $D, flip: $F, groups: $G")
+    nothing
 end
```