Skip to content

Commit 2f4725a

Browse files
Merging im2col and direct backend implementations
1 parent 46a52bd commit 2f4725a

File tree

1 file changed

+150
-145
lines changed

1 file changed

+150
-145
lines changed

src/conv.jl

Lines changed: 150 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -171,26 +171,33 @@ end
171171
# These are the GEMM types we will accelerate with `im2col`
172172
const G = Union{[x[2] for x in gemm_datatype_mappings]...}
173173

174-
for (front_name, backend) in (
175-
# This maps from public, front-facing name, to internal backend name
176-
:conv => :im2col,
177-
)
178-
174+
for (front_name, backend, signature) in (
175+
# Each entry is (public front-facing name, internal backend name, method signature spec)
176+
(:conv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
177+
(:conv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
178+
)
179179
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
180180
@eval begin
181-
# im2col-accelerated function forwarding definition
181+
182182
function $(Symbol("$(front_name)!"))(
183-
out::AbstractArray{T,5}, in1::AbstractArray{T,5},
184-
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: $G, C <: ConvDims}
183+
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
184+
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
185+
in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
186+
cdims::$(signature[4]),
187+
kwargs...) where {$(signature[5]...)}
188+
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
189+
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
190+
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
191+
end
185192

186193
x_cs = Iterators.partition(1:size(in1, 4),
187-
channels_in(cdims) ÷ groupcount(cdims))
194+
channels_in(cdims) ÷ groupcount(cdims))
188195
w_cs = Iterators.partition(1:size(in2, 5),
189-
channels_out(cdims) ÷ groupcount(cdims))
196+
channels_out(cdims) ÷ groupcount(cdims))
190197
cdims2 = basetype(C)(cdims,
191-
G = 1,
192-
C_in = channels_in(cdims) ÷ groupcount(cdims),
193-
C_out = channels_out(cdims) ÷ groupcount(cdims))
198+
G = 1,
199+
C_in = channels_in(cdims) ÷ groupcount(cdims),
200+
C_out = channels_out(cdims) ÷ groupcount(cdims))
194201

195202
Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
196203
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
@@ -205,51 +212,89 @@ for (front_name, backend) in (
205212
end
206213

207214
# Function forwarding definitions for the gradient primitives (im2col-accelerated and direct backends)
208-
function ∇conv_data!(out::AbstractArray{T,5}, in1::AbstractArray{T,5},
209-
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: G, C <: ConvDims}
210-
211-
dx_cs = Iterators.partition(1:size(out, 4),
212-
channels_in(cdims) ÷ groupcount(cdims))
213-
w_cs = Iterators.partition(1:size(in2, 5),
214-
channels_out(cdims) ÷ groupcount(cdims))
215-
dy_cs = Iterators.partition(1:size(in1, 4),
216-
channels_out(cdims) ÷ groupcount(cdims))
217-
cdims2 = basetype(C)(cdims,
218-
G = 1,
219-
C_in = channels_in(cdims) ÷ groupcount(cdims),
220-
C_out = channels_out(cdims) ÷ groupcount(cdims))
221-
222-
Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
223-
dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
224-
dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
225-
wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
226-
Threads.@spawn ∇conv_data_im2col!(dxv, dyv, wv, cdims2; kwargs...)
227-
end
215+
# Generate the in-place `∇conv_data!` entry point once per backend.
#
# Each table row is `(front_name, backend, signature)`. `signature` holds, in
# order: the `(eltype, rank)` spec of `out`, of `in1` and of `in2`, the type of
# `cdims`, and the `where` parameters of the generated method.
for (front_name, backend, signature) in (
    (:∇conv_data, :im2col,
     ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
    (:∇conv_data, :direct,
     ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
)
    out_spec, in1_spec, in2_spec, cdims_type, where_params = signature
    entry_name = Symbol(front_name, :!)
    backend_name = Symbol(front_name, :_, backend, :!)

    # Only the `direct` method has `yT`/`T1`/`T2` as type parameters, so the
    # slow-path warning is spliced in at code-generation time for that backend
    # only; the im2col method body never mentions those names.
    slow_path_warning = if backend === :direct
        quote
            # warn for Float32 + accidental Float64, but don't print warning for
            # ForwardDiff.Dual
            if yT == Float64
                @warn string("Slow fallback implementation invoked for ",
                             $(string(front_name)), "! ",
                             "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
            end
        end
    else
        nothing
    end

    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
    @eval begin
        # Split `out` and `in1` along dimension 4 and `in2` along dimension 5
        # into one chunk per group, then run the backend kernel on each chunk
        # in its own task. `kwargs` are forwarded untouched to the backend.
        # Returns `out`, which the backend kernels receive as views.
        function $(entry_name)(
                out::AbstractArray{$(out_spec[1]), $(out_spec[2])},
                in1::AbstractArray{$(in1_spec[1]), $(in1_spec[2])},
                in2::AbstractArray{$(in2_spec[1]), $(in2_spec[2])},
                cdims::$(cdims_type);   # `;` keeps `kwargs` as keyword arguments
                kwargs...) where {$(where_params...)}
            $(slow_path_warning)

            # Per-group channel counts, shared by the partitions and by `cdims2`.
            group_in = channels_in(cdims) ÷ groupcount(cdims)
            group_out = channels_out(cdims) ÷ groupcount(cdims)

            dx_cs = Iterators.partition(1:size(out, 4), group_in)
            w_cs = Iterators.partition(1:size(in2, 5), group_out)
            dy_cs = Iterators.partition(1:size(in1, 4), group_out)

            # Single-group dims describing the problem each task solves.
            cdims2 = basetype(C)(cdims, G = 1, C_in = group_in, C_out = group_out)

            Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
                dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
                dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
                wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
                Threads.@spawn $(backend_name)(dxv, dyv, wv, cdims2; kwargs...)
            end

            return out
        end
    end
end
251257

252-
return out
258+
# Generate the in-place `∇conv_filter!` entry point once per backend.
#
# Each table row is `(front_name, backend, signature)`. `signature` holds, in
# order: the `(eltype, rank)` spec of `out`, of `in1` and of `in2`, the type of
# `cdims`, and the `where` parameters of the generated method.
for (front_name, backend, signature) in (
    (:∇conv_filter, :im2col,
     ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
    (:∇conv_filter, :direct,
     ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
)
    out_spec, in1_spec, in2_spec, cdims_type, where_params = signature
    entry_name = Symbol(front_name, :!)
    backend_name = Symbol(front_name, :_, backend, :!)

    # Only the `direct` method has `yT`/`T1`/`T2` as type parameters, so the
    # slow-path warning is spliced in at code-generation time for that backend
    # only; the im2col method body never mentions those names.
    slow_path_warning = if backend === :direct
        quote
            # warn for Float32 + accidental Float64, but don't print warning for
            # ForwardDiff.Dual
            if yT == Float64
                @warn string("Slow fallback implementation invoked for ",
                             $(string(front_name)), "! ",
                             "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
            end
        end
    else
        nothing
    end

    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
    @eval begin
        # Split `out` along dimension 5 and `in1`/`in2` along dimension 4 into
        # one chunk per group, then run the backend kernel on each chunk in its
        # own task. `kwargs` are forwarded untouched to the backend.
        # Returns `out`, which the backend kernels receive as views.
        function $(entry_name)(
                out::AbstractArray{$(out_spec[1]), $(out_spec[2])},
                in1::AbstractArray{$(in1_spec[1]), $(in1_spec[2])},
                in2::AbstractArray{$(in2_spec[1]), $(in2_spec[2])},
                cdims::$(cdims_type);   # `;` keeps `kwargs` as keyword arguments
                kwargs...) where {$(where_params...)}
            $(slow_path_warning)

            # Per-group channel counts, shared by the partitions and by `cdims2`.
            group_in = channels_in(cdims) ÷ groupcount(cdims)
            group_out = channels_out(cdims) ÷ groupcount(cdims)

            dw_cs = Iterators.partition(1:size(out, 5), group_out)
            dy_cs = Iterators.partition(1:size(in2, 4), group_out)
            x_cs = Iterators.partition(1:size(in1, 4), group_in)

            # Single-group dims describing the problem each task solves.
            cdims2 = basetype(C)(cdims, G = 1, C_in = group_in, C_out = group_out)

            Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
                x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
                dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
                # Index `out` with its own partition (`wc`), not the one derived
                # from `in2`.
                dw = @view out[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
                Threads.@spawn $(backend_name)(dw, x, dy, cdims2; kwargs...)
            end

            return out
        end
    end
end
254299

255300

@@ -289,101 +334,61 @@ for front_name in (:depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_fil
289334
end
290335
end
291336

292-
for (front_name, backend) in (
293-
# This maps from public, front-facing name, to internal backend name
294-
:conv => :direct,
295-
# :∇conv_data => :direct,
296-
# :∇conv_filter => :direct,
297-
)
298-
299-
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
300-
@eval begin
301-
# im2col-accelerated function forwarding definition
302-
function $(Symbol("$(front_name)!"))(
303-
out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
304-
in2::AbstractArray{T2,N}, cdims::C;
305-
kwargs...) where {yT, T1, T2, N, C <: ConvDims}
306-
if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
307-
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
308-
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
309-
end
310-
311-
x_cs = Iterators.partition(1:size(in1, 4),
312-
channels_in(cdims) ÷ groupcount(cdims))
313-
w_cs = Iterators.partition(1:size(in2, 5),
314-
channels_out(cdims) ÷ groupcount(cdims))
315-
cdims2 = basetype(C)(cdims,
316-
G = 1,
317-
C_in = channels_in(cdims) ÷ groupcount(cdims),
318-
C_out = channels_out(cdims) ÷ groupcount(cdims))
319-
320-
Threads.@sync for (xc, wc) in zip(x_cs, w_cs)
321-
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
322-
w = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
323-
y = @view out[ntuple(i -> i == 4 ? wc : Colon(), 5)...]
324-
Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(y, x, w, cdims2; kwargs...)
325-
end
326-
327-
return out
328-
end
329-
end
330-
end
331-
332-
# direct function forwarding definition
333-
function ∇conv_data!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
334-
in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
335-
if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
336-
@warn string("Slow fallback implementation invoked for ", string(front_name), "! ",
337-
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
338-
end
337+
# # direct function forwarding definition
338+
# function ∇conv_data!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
339+
# in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
340+
# if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
341+
# @warn string("Slow fallback implementation invoked for ", string(front_name), "! ",
342+
# "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
343+
# end
339344

340-
dx_cs = Iterators.partition(1:size(out, 4),
341-
channels_in(cdims) ÷ groupcount(cdims))
342-
w_cs = Iterators.partition(1:size(in2, 5),
343-
channels_out(cdims) ÷ groupcount(cdims))
344-
dy_cs = Iterators.partition(1:size(in1, 4),
345-
channels_out(cdims) ÷ groupcount(cdims))
346-
cdims2 = basetype(C)(cdims,
347-
G = 1,
348-
C_in = channels_in(cdims) ÷ groupcount(cdims),
349-
C_out = channels_out(cdims) ÷ groupcount(cdims))
350-
351-
Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
352-
dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
353-
dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
354-
wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
355-
Threads.@spawn ∇conv_data_direct!(dxv, dyv, wv, cdims2; kwargs...)
356-
end
345+
# dx_cs = Iterators.partition(1:size(out, 4),
346+
# channels_in(cdims) ÷ groupcount(cdims))
347+
# w_cs = Iterators.partition(1:size(in2, 5),
348+
# channels_out(cdims) ÷ groupcount(cdims))
349+
# dy_cs = Iterators.partition(1:size(in1, 4),
350+
# channels_out(cdims) ÷ groupcount(cdims))
351+
# cdims2 = basetype(C)(cdims,
352+
# G = 1,
353+
# C_in = channels_in(cdims) ÷ groupcount(cdims),
354+
# C_out = channels_out(cdims) ÷ groupcount(cdims))
355+
356+
# Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
357+
# dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
358+
# dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
359+
# wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
360+
# Threads.@spawn ∇conv_data_direct!(dxv, dyv, wv, cdims2; kwargs...)
361+
# end
357362

358-
return out
359-
end
363+
# return out
364+
# end
360365

361-
function ∇conv_filter!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
362-
in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
363-
if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
364-
@warn string("Slow fallback implementation invoked for ", string(front_name), "! ",
365-
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
366-
end
367-
dw_cs = Iterators.partition(1:size(out, 5),
368-
channels_out(cdims) ÷ groupcount(cdims))
369-
dy_cs = Iterators.partition(1:size(in2, 4),
370-
channels_out(cdims) ÷ groupcount(cdims))
371-
x_cs = Iterators.partition(1:size(in1, 4),
372-
channels_in(cdims) ÷ groupcount(cdims))
373-
cdims2 = basetype(C)(cdims,
374-
G = 1,
375-
C_in = channels_in(cdims) ÷ groupcount(cdims),
376-
C_out = channels_out(cdims) ÷ groupcount(cdims))
377-
378-
Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
379-
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
380-
dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
381-
dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
382-
Threads.@spawn ∇conv_filter_direct!(dw, x, dy, cdims2; kwargs...)
383-
end
366+
# function ∇conv_filter!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
367+
# in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
368+
# if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
369+
# @warn string("Slow fallback implementation invoked for ", string(front_name), "! ",
370+
# "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
371+
# end
372+
# dw_cs = Iterators.partition(1:size(out, 5),
373+
# channels_out(cdims) ÷ groupcount(cdims))
374+
# dy_cs = Iterators.partition(1:size(in2, 4),
375+
# channels_out(cdims) ÷ groupcount(cdims))
376+
# x_cs = Iterators.partition(1:size(in1, 4),
377+
# channels_in(cdims) ÷ groupcount(cdims))
378+
# cdims2 = basetype(C)(cdims,
379+
# G = 1,
380+
# C_in = channels_in(cdims) ÷ groupcount(cdims),
381+
# C_out = channels_out(cdims) ÷ groupcount(cdims))
382+
383+
# Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
384+
# x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
385+
# dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
386+
# dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
387+
# Threads.@spawn ∇conv_filter_direct!(dw, x, dy, cdims2; kwargs...)
388+
# end
384389

385-
return out
386-
end
390+
# return out
391+
# end
387392

388393
for Dims in [:DenseConvDims, :DepthwiseConvDims, :PoolDims]
389394
@eval @non_differentiable $Dims(::Any...)

0 commit comments

Comments
 (0)