|
292 | 292 | for (front_name, backend) in (
|
293 | 293 | # This maps from public, front-facing name, to internal backend name
|
294 | 294 | :conv => :direct,
|
295 |
| - :∇conv_data => :direct, |
296 |
| - :∇conv_filter => :direct, |
| 295 | + # :∇conv_data => :direct, |
| 296 | + # :∇conv_filter => :direct, |
297 | 297 | )
|
298 | 298 |
|
299 | 299 | # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
|
@@ -329,6 +329,62 @@ for (front_name, backend) in (
|
329 | 329 | end
|
330 | 330 | end
|
331 | 331 |
|
"""
    ∇conv_data!(out, in1, in2, cdims::ConvDims; kwargs...)

Direct-backend forwarding definition for `∇conv_data!`.

The work is split by group: `out` and `in1` are sliced along dimension 4, `in2` is
sliced along dimension 5, and each triple of slices is handed to `∇conv_data_direct!`
on its own task together with a single-group copy of `cdims`. All `kwargs` are
forwarded unchanged. Returns `out`.
"""
function ∇conv_data!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
                     in2::AbstractArray{T2,N}, cdims::C;
                     kwargs...) where {yT, T1, T2, N, C <: ConvDims}
    # Warn for Float32 + accidental Float64, but don't print a warning for
    # ForwardDiff.Dual. The function name is spelled out literally: this method lives
    # outside the `for (front_name, backend)` loop, so `front_name` is not bound here.
    if yT == Float64
        @warn string("Slow fallback implementation invoked for ∇conv_data! ",
                     "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
    end

    # Channel counts handled by one group.
    group_c_in = channels_in(cdims) ÷ groupcount(cdims)
    group_c_out = channels_out(cdims) ÷ groupcount(cdims)

    dx_cs = Iterators.partition(1:size(out, 4), group_c_in)
    w_cs = Iterators.partition(1:size(in2, 5), group_c_out)
    dy_cs = Iterators.partition(1:size(in1, 4), group_c_out)

    # Same dims as `cdims`, but describing a single group.
    cdims2 = basetype(C)(cdims, G = 1, C_in = group_c_in, C_out = group_c_out)

    # One task per group; `@sync` waits for all of them before returning.
    Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
        dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
        dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
        wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
        Threads.@spawn ∇conv_data_direct!(dxv, dyv, wv, cdims2; kwargs...)
    end

    return out
end
| 360 | + |
"""
    ∇conv_filter!(out, in1, in2, cdims::ConvDims; kwargs...)

Direct-backend forwarding definition for `∇conv_filter!`.

The work is split by group: `out` is sliced along dimension 5, `in1` and `in2` are
sliced along dimension 4, and each triple of slices is handed to
`∇conv_filter_direct!` on its own task together with a single-group copy of `cdims`.
All `kwargs` are forwarded unchanged. Returns `out`.
"""
function ∇conv_filter!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
                       in2::AbstractArray{T2,N}, cdims::C;
                       kwargs...) where {yT, T1, T2, N, C <: ConvDims}
    # Warn for Float32 + accidental Float64, but don't print a warning for
    # ForwardDiff.Dual. The function name is spelled out literally: this method lives
    # outside the `for (front_name, backend)` loop, so `front_name` is not bound here.
    if yT == Float64
        @warn string("Slow fallback implementation invoked for ∇conv_filter! ",
                     "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
    end

    # Channel counts handled by one group.
    group_c_in = channels_in(cdims) ÷ groupcount(cdims)
    group_c_out = channels_out(cdims) ÷ groupcount(cdims)

    dw_cs = Iterators.partition(1:size(out, 5), group_c_out)
    dy_cs = Iterators.partition(1:size(in2, 4), group_c_out)
    x_cs = Iterators.partition(1:size(in1, 4), group_c_in)

    # Same dims as `cdims`, but describing a single group.
    cdims2 = basetype(C)(cdims, G = 1, C_in = group_c_in, C_out = group_c_out)

    # One task per group; `@sync` waits for all of them before returning.
    Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
        x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
        dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
        # Slice `out` with its own partition (`wc`) rather than borrowing `dy`'s range.
        dw = @view out[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
        Threads.@spawn ∇conv_filter_direct!(dw, x, dy, cdims2; kwargs...)
    end

    return out
end
| 387 | + |
# Register the constructor of each dims type, for any arguments, as non-differentiable.
for dims_type in (:DenseConvDims, :DepthwiseConvDims, :PoolDims)
    @eval @non_differentiable $dims_type(::Any...)
end
|
|
0 commit comments