Skip to content

Commit 46a52bd

Browse files
conv_data! and conv_filter! separate def
1 parent 6a20153 commit 46a52bd

File tree

1 file changed

+58
-2
lines changed

1 file changed

+58
-2
lines changed

src/conv.jl

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,8 @@ end
292292
for (front_name, backend) in (
293293
# This maps from public, front-facing name, to internal backend name
294294
:conv => :direct,
295-
:∇conv_data => :direct,
296-
:∇conv_filter => :direct,
295+
# :∇conv_data => :direct,
296+
# :∇conv_filter => :direct,
297297
)
298298

299299
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
@@ -329,6 +329,62 @@ for (front_name, backend) in (
329329
end
330330
end
331331

332+
# direct function forwarding definition
333+
function ∇conv_data!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
334+
in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
335+
if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
336+
@warn string("Slow fallback implementation invoked for ", string(front_name), "! ",
337+
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
338+
end
339+
340+
dx_cs = Iterators.partition(1:size(out, 4),
341+
channels_in(cdims) ÷ groupcount(cdims))
342+
w_cs = Iterators.partition(1:size(in2, 5),
343+
channels_out(cdims) ÷ groupcount(cdims))
344+
dy_cs = Iterators.partition(1:size(in1, 4),
345+
channels_out(cdims) ÷ groupcount(cdims))
346+
cdims2 = basetype(C)(cdims,
347+
G = 1,
348+
C_in = channels_in(cdims) ÷ groupcount(cdims),
349+
C_out = channels_out(cdims) ÷ groupcount(cdims))
350+
351+
Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
352+
dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
353+
dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
354+
wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
355+
Threads.@spawn ∇conv_data_direct!(dxv, dyv, wv, cdims2; kwargs...)
356+
end
357+
358+
return out
359+
end
360+
361+
function ∇conv_filter!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
362+
in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
363+
if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
364+
@warn string("Slow fallback implementation invoked for ", string(front_name), "! ",
365+
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
366+
end
367+
dw_cs = Iterators.partition(1:size(out, 5),
368+
channels_out(cdims) ÷ groupcount(cdims))
369+
dy_cs = Iterators.partition(1:size(in2, 4),
370+
channels_out(cdims) ÷ groupcount(cdims))
371+
x_cs = Iterators.partition(1:size(in1, 4),
372+
channels_in(cdims) ÷ groupcount(cdims))
373+
cdims2 = basetype(C)(cdims,
374+
G = 1,
375+
C_in = channels_in(cdims) ÷ groupcount(cdims),
376+
C_out = channels_out(cdims) ÷ groupcount(cdims))
377+
378+
Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
379+
x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
380+
dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
381+
dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
382+
Threads.@spawn ∇conv_filter_direct!(dw, x, dy, cdims2; kwargs...)
383+
end
384+
385+
return out
386+
end
387+
332388
for Dims in [:DenseConvDims, :DepthwiseConvDims, :PoolDims]
333389
@eval @non_differentiable $Dims(::Any...)
334390
end

0 commit comments

Comments
 (0)