Skip to content

Commit cde10cf

Browse files
Added better explantaion on merged implemenntation
1 parent 9987269 commit cde10cf

File tree

1 file changed

+29
-86
lines changed

1 file changed

+29
-86
lines changed

src/conv.jl

Lines changed: 29 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,16 @@ end
166166

167167
# First, we will define mappings from the generic API names to our accelerated backend
168168
# implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using
169-
# im2col + GEMM. Do so in a loop, here:
169+
# im2col + GEMM.
170+
# But we always support a fallback, non-accelerated path, where we use the direct, but
171+
# slow, implementations. These should not typically be used, hence the `@warn`,
170172

171173
# These are the GEMM types we will accelerate with `im2col`
172174
const G = Union{[x[2] for x in gemm_datatype_mappings]...}
173175

174176
for (front_name, backend, signature) in (
175-
# This maps from public, front-facing name, to internal backend name
177+
# This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
178+
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
176179
(:conv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
177180
(:conv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
178181
)
@@ -213,7 +216,8 @@ end
213216

214217
# im2col-accelerated function forwarding definition
215218
for (front_name, backend, signature) in (
216-
# This maps from public, front-facing name, to internal backend name
219+
# This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
220+
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
217221
(:∇conv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
218222
(:∇conv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
219223
)
@@ -256,13 +260,13 @@ for (front_name, backend, signature) in (
256260
end
257261

258262
for (front_name, backend, signature) in (
259-
# This maps from public, front-facing name, to internal backend name
263+
# This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
264+
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
260265
(:∇conv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
261266
(:∇conv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
262267
)
263268
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
264269
@eval begin
265-
# println($(Symbol(["$(i)" for i in "$(signature[5])"]...))...)
266270
function $(Symbol("$(front_name)!"))(
267271
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
268272
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
@@ -298,98 +302,37 @@ for (front_name, backend, signature) in (
298302
end
299303

300304

301-
for (front_name, backend) in (
302-
# This maps from public, front-facing name, to internal backend name
303-
:depthwiseconv => :im2col,
304-
:∇depthwiseconv_data => :im2col,
305-
:∇depthwiseconv_filter => :im2col,
306-
)
305+
for (front_name, backend, signature) in (
306+
# This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
307+
# (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, (parametric Types)))
308+
(:depthwiseconv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
309+
(:depthwiseconv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
310+
311+
(:∇depthwiseconv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
312+
(:∇depthwiseconv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
313+
314+
(:∇depthwiseconv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
315+
(:∇depthwiseconv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
316+
)
307317

308318
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
309319
@eval begin
310320
# im2col-accelerated function forwarding definition
311321
function $(Symbol("$(front_name)!"))(
312-
out::AbstractArray{T,5}, in1::AbstractArray{T,5},
313-
in2::AbstractArray{T,5}, cdims::C; kwargs...) where {T <: $G, C <: ConvDims}
314-
$(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
315-
end
316-
end
317-
end
318-
319-
# We always support a fallback, non-accelerated path, where we use the direct, but
320-
# slow, implementations. These should not typically be used, hence the `@warn`,
321-
# but let's go ahead and define them first:
322-
for front_name in (:depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)
323-
@eval begin
324-
function $(Symbol("$(front_name)!"))(
325-
y::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
326-
in2::AbstractArray{T2,N}, cdims::ConvDims;
327-
kwargs...) where {yT, T1, T2, N}
328-
if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
322+
out::AbstractArray{$(signature[1][1]), $(signature[1][2])},
323+
in1::AbstractArray{$(signature[2][1]), $(signature[1][2])},
324+
in2::AbstractArray{$(signature[3][1]), $(signature[1][2])},
325+
cdims::$(signature[4]),
326+
kwargs...) where {$(signature[5]...)}
327+
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
329328
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
330-
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
329+
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
331330
end
332-
$(Symbol("$(front_name)_direct!"))(y, in1, in2, cdims; kwargs...)
331+
$(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
333332
end
334333
end
335334
end
336335

337-
# # direct function forwarding definition
338-
# function ∇conv_data!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
339-
# in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
340-
# if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
341-
# @warn string("Slow fallback implementation invoked for ", string(front_name), "! ",
342-
# "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
343-
# end
344-
345-
# dx_cs = Iterators.partition(1:size(out, 4),
346-
# channels_in(cdims) ÷ groupcount(cdims))
347-
# w_cs = Iterators.partition(1:size(in2, 5),
348-
# channels_out(cdims) ÷ groupcount(cdims))
349-
# dy_cs = Iterators.partition(1:size(in1, 4),
350-
# channels_out(cdims) ÷ groupcount(cdims))
351-
# cdims2 = basetype(C)(cdims,
352-
# G = 1,
353-
# C_in = channels_in(cdims) ÷ groupcount(cdims),
354-
# C_out = channels_out(cdims) ÷ groupcount(cdims))
355-
356-
# Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
357-
# dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
358-
# dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
359-
# wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
360-
# Threads.@spawn ∇conv_data_direct!(dxv, dyv, wv, cdims2; kwargs...)
361-
# end
362-
363-
# return out
364-
# end
365-
366-
# function ∇conv_filter!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
367-
# in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
368-
# if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
369-
# @warn string("Slow fallback implementation invoked for ", string(front_name), "! ",
370-
# "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
371-
# end
372-
# dw_cs = Iterators.partition(1:size(out, 5),
373-
# channels_out(cdims) ÷ groupcount(cdims))
374-
# dy_cs = Iterators.partition(1:size(in2, 4),
375-
# channels_out(cdims) ÷ groupcount(cdims))
376-
# x_cs = Iterators.partition(1:size(in1, 4),
377-
# channels_in(cdims) ÷ groupcount(cdims))
378-
# cdims2 = basetype(C)(cdims,
379-
# G = 1,
380-
# C_in = channels_in(cdims) ÷ groupcount(cdims),
381-
# C_out = channels_out(cdims) ÷ groupcount(cdims))
382-
383-
# Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
384-
# x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
385-
# dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
386-
# dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
387-
# Threads.@spawn ∇conv_filter_direct!(dw, x, dy, cdims2; kwargs...)
388-
# end
389-
390-
# return out
391-
# end
392-
393336
for Dims in [:DenseConvDims, :DepthwiseConvDims, :PoolDims]
394337
@eval @non_differentiable $Dims(::Any...)
395338
end

0 commit comments

Comments
 (0)