Skip to content

Commit e566635

Browse files
authored
Merge pull request #468 from gabrielpreviato/conv-direct-groups
Fix conv with groups when falling in direct backend
2 parents 6518f40 + f7b59af commit e566635

File tree

2 files changed

+132
-79
lines changed

2 files changed

+132
-79
lines changed

src/conv.jl

Lines changed: 121 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -166,31 +166,41 @@ end
166166

167167
# First, we will define mappings from the generic API names to our accelerated backend
168168
# implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using
169-
# im2col + GEMM. Do so in a loop, here:
169+
# im2col + GEMM.
170+
# But we always support a fallback, non-accelerated path, where we use the direct, but
171+
# slow, implementations. These should not typically be used, hence the `@warn`.
170172

171173
# These are the GEMM types we will accelerate with `im2col`
172174
const G = Union{[x[2] for x in gemm_datatype_mappings]...}
173175

174-
for (front_name, backend) in (
175-
# This maps from public, front-facing name, to internal backend name
176-
:conv => :im2col,
177-
)
178-
176+
for (front_name, backend, signature) in (
177+
# This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
178+
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, cdims type, (where-clause type parameters)))
179+
(:conv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
180+
(:conv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
181+
)
179182
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
180183
@eval begin
181-
# im2col-accelerated function forwarding definition
184+
182185
function $(Symbol("$(front_name)!"))(
183-
out::AbstractArray{T,5}, in1::AbstractArray{T,5},
184-
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: $G, C <: ConvDims}
186+
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
187+
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
188+
in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
189+
cdims::$(signature[4]);
190+
kwargs...) where {$(signature[5]...)}
191+
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
192+
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
193+
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
194+
end
185195

186196
x_cs = Iterators.partition(1:size(in1, 4),
187-
channels_in(cdims) ÷ groupcount(cdims))
197+
channels_in(cdims) ÷ groupcount(cdims))
188198
w_cs = Iterators.partition(1:size(in2, 5),
189-
channels_out(cdims) ÷ groupcount(cdims))
199+
channels_out(cdims) ÷ groupcount(cdims))
190200
cdims2 = basetype(C)(cdims,
191-
G = 1,
192-
C_in = channels_in(cdims) ÷ groupcount(cdims),
193-
C_out = channels_out(cdims) ÷ groupcount(cdims))
201+
G = 1,
202+
C_in = channels_in(cdims) ÷ groupcount(cdims),
203+
C_out = channels_out(cdims) ÷ groupcount(cdims))
194204

195205
Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
196206
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
@@ -205,87 +215,119 @@ for (front_name, backend) in (
205215
end
206216

207217
# Function forwarding definitions, for both the im2col-accelerated and the direct (fallback) backends
208-
function ∇conv_data!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
209-
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: G, C <: ConvDims}
210-
211-
dx_cs = Iterators.partition(1:size(out, 4),
212-
channels_in(cdims) ÷ groupcount(cdims))
213-
w_cs = Iterators.partition(1:size(in2, 5),
214-
channels_out(cdims) ÷ groupcount(cdims))
215-
dy_cs = Iterators.partition(1:size(in1, 4),
216-
channels_out(cdims) ÷ groupcount(cdims))
217-
cdims2 = basetype(C)(cdims,
218-
G = 1,
219-
C_in = channels_in(cdims) ÷ groupcount(cdims),
220-
C_out = channels_out(cdims) ÷ groupcount(cdims))
221-
222-
Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
223-
dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
224-
dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
225-
wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
226-
Threads.@spawn ∇conv_data_im2col!(dxv, dyv, wv, cdims2; kwargs...)
227-
end
218+
for (front_name, backend, signature) in (
219+
# This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
220+
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, cdims type, (where-clause type parameters)))
221+
(:∇conv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
222+
(:∇conv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
223+
)
224+
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
225+
@eval begin
226+
function $(Symbol("$(front_name)!"))(
227+
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
228+
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
229+
in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
230+
cdims::$(signature[4]);
231+
kwargs...) where {$(signature[5]...)}
232+
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
233+
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
234+
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
235+
end
228236

229-
return out
230-
end
231237

232-
function ∇conv_filter!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
233-
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: G, C <: ConvDims}
234-
dw_cs = Iterators.partition(1:size(out, 5),
235-
channels_out(cdims) ÷ groupcount(cdims))
236-
dy_cs = Iterators.partition(1:size(in2, 4),
237-
channels_out(cdims) ÷ groupcount(cdims))
238-
x_cs = Iterators.partition(1:size(in1, 4),
239-
channels_in(cdims) ÷ groupcount(cdims))
240-
cdims2 = basetype(C)(cdims,
241-
G = 1,
242-
C_in = channels_in(cdims) ÷ groupcount(cdims),
243-
C_out = channels_out(cdims) ÷ groupcount(cdims))
244-
245-
Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
246-
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
247-
dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
248-
dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
249-
Threads.@spawn ∇conv_filter_im2col!(dw, x, dy, cdims2; kwargs...)
250-
end
238+
dx_cs = Iterators.partition(1:size(out, 4),
239+
channels_in(cdims) ÷ groupcount(cdims))
240+
w_cs = Iterators.partition(1:size(in2, 5),
241+
channels_out(cdims) ÷ groupcount(cdims))
242+
dy_cs = Iterators.partition(1:size(in1, 4),
243+
channels_out(cdims) ÷ groupcount(cdims))
244+
cdims2 = basetype(C)(cdims,
245+
G = 1,
246+
C_in = channels_in(cdims) ÷ groupcount(cdims),
247+
C_out = channels_out(cdims) ÷ groupcount(cdims))
248+
249+
Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
250+
dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
251+
dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
252+
wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
253+
Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(dxv, dyv, wv, cdims2; kwargs...)
254+
end
251255

252-
return out
256+
return out
257+
end
258+
end
253259
end
254260

255-
256-
for (front_name, backend) in (
257-
# This maps from public, front-facing name, to internal backend name
258-
:depthwiseconv => :im2col,
259-
:∇depthwiseconv_data => :im2col,
260-
:∇depthwiseconv_filter => :im2col,
261-
)
262-
261+
for (front_name, backend, signature) in (
262+
# This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
263+
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, cdims type, (where-clause type parameters)))
264+
(:∇conv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
265+
(:∇conv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
266+
)
263267
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
264268
@eval begin
265-
# im2col-accelerated function forwarding definition
266269
function $(Symbol("$(front_name)!"))(
267-
out::AbstractArray{T,5}, in1::AbstractArray{T,5},
268-
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: $G, C <: ConvDims}
269-
$(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
270+
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
271+
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
272+
in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
273+
cdims::$(signature[4]);
274+
kwargs...) where {$(signature[5]...)}
275+
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
276+
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
277+
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
278+
end
279+
280+
dw_cs = Iterators.partition(1:size(out, 5),
281+
channels_out(cdims) ÷ groupcount(cdims))
282+
dy_cs = Iterators.partition(1:size(in2, 4),
283+
channels_out(cdims) ÷ groupcount(cdims))
284+
x_cs = Iterators.partition(1:size(in1, 4),
285+
channels_in(cdims) ÷ groupcount(cdims))
286+
cdims2 = basetype(C)(cdims,
287+
G = 1,
288+
C_in = channels_in(cdims) ÷ groupcount(cdims),
289+
C_out = channels_out(cdims) ÷ groupcount(cdims))
290+
291+
Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
292+
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
293+
dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
294+
dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
295+
Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(dw, x, dy, cdims2; kwargs...)
296+
end
297+
298+
return out
270299
end
271300
end
272301
end
273302

274-
# We always support a fallback, non-accelerated path, where we use the direct, but
275-
# slow, implementations. These should not typically be used, hence the `@warn`,
276-
# but let's go ahead and define them first:
277-
for front_name in (:conv, :∇conv_data, :∇conv_filter,
278-
:depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)
303+
304+
for (front_name, backend, signature) in (
305+
# This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
306+
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, cdims type, (where-clause type parameters)))
307+
(:depthwiseconv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
308+
(:depthwiseconv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
309+
310+
(:∇depthwiseconv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
311+
(:∇depthwiseconv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
312+
313+
(:∇depthwiseconv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
314+
(:∇depthwiseconv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
315+
)
316+
317+
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
279318
@eval begin
319+
# Function forwarding definition to the selected backend (im2col-accelerated or direct fallback)
280320
function $(Symbol("$(front_name)!"))(
281-
y::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
282-
in2::AbstractArray{T2,N}, cdims::ConvDims;
283-
kwargs...) where {yT, T1, T2, N}
284-
if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
321+
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
322+
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
323+
in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
324+
cdims::$(signature[4]);
325+
kwargs...) where {$(signature[5]...)}
326+
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
285327
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
286-
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
328+
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
287329
end
288-
$(Symbol("$(front_name)_direct!"))(y, in1, in2, cdims; kwargs...)
330+
$(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
289331
end
290332
end
291333
end

test/conv.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,17 @@ end
725725
@test size(conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true)) == (12, 7, 16, 10)
726726
end
727727

728+
# https://github.com/FluxML/NNlib.jl/issues/369
729+
@testset "conv_wrapper with groups - not equal types that trigger direct backend" begin
730+
x = rand(Float32, 10, 10, 32, 8)
731+
w = rand(Float64, 2, 2, 16, 4)
732+
g = 2
733+
@test conv(x, w; groups=g) ≈ conv(x, Float32.(w); groups=g)
734+
@test conv(x, w; stride = (2, 2), pad = (2, 2), groups=g) ≈ conv(x, w; stride = (2, 2), pad = (2, 2), groups=g)
735+
@test conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), groups=g) ≈ conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), groups=g)
736+
@test conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true, groups=g) ≈ conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true, groups=g)
737+
end
738+
728739
@testset "depthwiseconv_wrapper" begin
729740
x = rand(10, 10, 3, 10)
730741
w = rand(2, 2, 3, 3)

0 commit comments

Comments
 (0)