Skip to content

Commit 0902922

Browse files
conv_direct functions w/ correct groups partition
1 parent 0570f6e commit 0902922

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

src/conv.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,7 @@ end
274274
# We always support a fallback, non-accelerated path, where we use the direct, but
275275
# slow, implementations. These should not typically be used, hence the `@warn`,
276276
# but let's go ahead and define them first:
277-
for front_name in (:conv, :∇conv_data, :∇conv_filter,
278-
:depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)
277+
for front_name in (:depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter)
279278
@eval begin
280279
function $(Symbol("$(front_name)!"))(
281280
y::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
@@ -290,6 +289,46 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter,
290289
end
291290
end
292291

292+
# Fallback (direct backend) definitions of `conv!`, `∇conv_data!` and `∇conv_filter!`
# with support for grouped convolution.
#
# A grouped convolution is evaluated as `groupcount(cdims)` independent ungrouped
# convolutions, each working on views of the arguments restricted to one group's
# channels. Which dimension carries the group-partitioned channels, and how many channels
# live there, differs per function and per argument, so every entry below carries one
# `(dim, accessor)` pair for each of `out`, `in1` and `in2`:
#
#   * `dim`      -- dimension of the 5d array that is sliced per group
#   * `accessor` -- function of `cdims` giving the total channel count on that dimension
#
# Argument roles (per the NNlib calling convention; the weight array keeps its output
# channels on dim 5, feature maps and their gradients keep channels on dim 4):
#
#   conv!(y, x, w)           out = y  (C_out, dim 4), in1 = x  (C_in,  dim 4), in2 = w  (C_out, dim 5)
#   ∇conv_data!(dx, dy, w)   out = dx (C_in,  dim 4), in1 = dy (C_out, dim 4), in2 = w  (C_out, dim 5)
#   ∇conv_filter!(dw, x, dy) out = dw (C_out, dim 5), in1 = x  (C_in,  dim 4), in2 = dy (C_out, dim 4)
for (front_name, backend, out_spec, in1_spec, in2_spec) in (
    # This maps from public, front-facing name, to internal backend name
    (:conv,         :direct, (4, :channels_out), (4, :channels_in),  (5, :channels_out)),
    (:∇conv_data,   :direct, (4, :channels_in),  (4, :channels_out), (5, :channels_out)),
    (:∇conv_filter, :direct, (5, :channels_out), (4, :channels_in),  (4, :channels_out)),
)
    out_dim, out_ch = out_spec
    in1_dim, in1_ch = in1_spec
    in2_dim, in2_ch = in2_spec

    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
    @eval begin
        # direct (non-accelerated) fallback forwarding definition
        function $(Symbol("$(front_name)!"))(
                out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
                in2::AbstractArray{T2,N}, cdims::C;
                kwargs...) where {yT, T1, T2, N, C <: ConvDims}
            if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
                @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                             "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
            end

            groups = groupcount(cdims)

            # One index range per group along the channel dimension of each argument.
            out_cs = Iterators.partition(1:size(out, $out_dim), $(out_ch)(cdims) ÷ groups)
            in1_cs = Iterators.partition(1:size(in1, $in1_dim), $(in1_ch)(cdims) ÷ groups)
            in2_cs = Iterators.partition(1:size(in2, $in2_dim), $(in2_ch)(cdims) ÷ groups)

            # Every group is an ordinary (G = 1) convolution over its share of channels.
            cdims2 = basetype(C)(cdims,
                                 G = 1,
                                 C_in = channels_in(cdims) ÷ groups,
                                 C_out = channels_out(cdims) ÷ groups)

            # Groups write to disjoint slices of `out`, so they can run concurrently.
            Threads.@sync for (oc, c1, c2) in zip(out_cs, in1_cs, in2_cs)
                o  = @view out[ntuple(i -> i == $out_dim ? oc : Colon(), 5)...]
                a1 = @view in1[ntuple(i -> i == $in1_dim ? c1 : Colon(), 5)...]
                a2 = @view in2[ntuple(i -> i == $in2_dim ? c2 : Colon(), 5)...]
                Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(o, a1, a2, cdims2; kwargs...)
            end

            return out
        end
    end
end
331+
293332
# For each listed dims type name, generate (via `@eval`) a `@non_differentiable`
# declaration covering calls to that constructor with any arguments.
# NOTE(review): `@non_differentiable` is presumably ChainRulesCore's macro; the file's
# imports are not visible here, so confirm.
for Dims in [:DenseConvDims, :DepthwiseConvDims, :PoolDims]
294333
@eval @non_differentiable $Dims(::Any...)
295334
end

test/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,7 @@ end
725725
@test size(conv(x, w; stride = (1, 2), pad = (2, 3), dilation = (2, 2), flipped = true)) == (12, 7, 16, 10)
726726
end
727727

728-
# https://github.com/FluxML/NNlib.jl/pull/171
728+
# https://github.com/FluxML/NNlib.jl/issues/369
729729
@testset "conv_wrapper with groups - not equal types that trigger direct backend" begin
730730
x = rand(Float32, 10, 10, 32, 8)
731731
w = rand(Float64, 2, 2, 16, 4)

0 commit comments

Comments
 (0)