274
274
# We always support a fallback, non-accelerated path, where we use the direct, but
275
275
# slow, implementations. These should not typically be used, hence the `@warn`,
276
276
# but let's go ahead and define them first:
277
- for front_name in (:conv , :∇conv_data , :∇conv_filter ,
278
- :depthwiseconv , :∇depthwiseconv_data , :∇depthwiseconv_filter )
277
+ for front_name in (:depthwiseconv , :∇depthwiseconv_data , :∇depthwiseconv_filter )
279
278
@eval begin
280
279
function $ (Symbol (" $(front_name) !" ))(
281
280
y:: AbstractArray{yT,N} , in1:: AbstractArray{T1,N} ,
@@ -290,6 +289,46 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter,
290
289
end
291
290
end
292
291
292
# Slow, non-accelerated fallbacks for the (grouped) dense convolution primitives.
#
# Each generated `<front_name>!` method splits the channels into `groupcount(cdims)`
# contiguous groups, takes matching views of the three arrays, and runs the
# `<front_name>_<backend>!` kernel on every group with a group-free copy of `cdims`
# (`G = 1`, per-group channel counts), one spawned task per group.
#
# Which array carries which kind of channel depends on the primitive, so every table
# row records, for `out`, `in1` and `in2` (in that order), the dimension to slice and
# whether it is indexed by the input-channel range (`cin`) or the output-channel range
# (`cout`) of the current group:
#
#   conv!(y, x, w)             y:  dim 4 / cout   x:  dim 4 / cin    w:  dim 5 / cout
#   ∇conv_data!(dx, dy, w)     dx: dim 4 / cin    dy: dim 4 / cout   w:  dim 5 / cout
#   ∇conv_filter!(dw, x, dy)   dw: dim 5 / cout   x:  dim 4 / cin    dy: dim 4 / cout
#
# Slicing all three primitives the way `conv!` is sliced (as was done before) partitions
# the batch dimension of `dy` in `∇conv_filter!` and indexes `dx`/`dw` with the wrong
# channel range in both gradient primitives.
#
# NOTE(review): this assumes data arrays are laid out as (spatial..., channels, batch)
# and filters as (spatial..., C_in ÷ G, C_out), and that the gradient primitives take
# their arguments in the order shown above — confirm against the `*_direct!` kernels.
for (front_name, backend, (out_dim, out_sel), (in1_dim, in1_sel), (in2_dim, in2_sel)) in (
        # public name, backend, then (channel dimension, channel range) for out/in1/in2
        (:conv,         :direct, (4, :cout), (4, :cin),  (5, :cout)),
        (:∇conv_data,   :direct, (4, :cin),  (4, :cout), (5, :cout)),
        (:∇conv_filter, :direct, (5, :cout), (4, :cin),  (4, :cout)),
    )

    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
    @eval begin
        # Direct (slow) fallback forwarding definition
        function $(Symbol("$(front_name)!"))(
                out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
                in2::AbstractArray{T2,N}, cdims::C;
                kwargs...) where {yT, T1, T2, N, C <: ConvDims}
            # Warn for Float32 + accidental Float64, but don't print a warning for
            # other element types such as ForwardDiff.Dual.
            if yT == Float64
                @warn string("Slow fallback implementation invoked for ",
                             $(string(front_name)), "! ",
                             "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
            end

            groups = groupcount(cdims)
            cin_per_group = channels_in(cdims) ÷ groups
            cout_per_group = channels_out(cdims) ÷ groups

            # Same geometry as `cdims`, but describing a single, ungrouped convolution
            # over one group's worth of channels.
            cdims2 = basetype(C)(cdims,
                                 G = 1,
                                 C_in = cin_per_group,
                                 C_out = cout_per_group)

            # Contiguous channel ranges belonging to each group.
            cin_groups = Iterators.partition(1:channels_in(cdims), cin_per_group)
            cout_groups = Iterators.partition(1:channels_out(cdims), cout_per_group)

            # One task per group; `@sync` waits for all of them before returning.
            Threads.@sync for (cin, cout) in zip(cin_groups, cout_groups)
                # `$(…_sel)` expands to either `cin` or `cout`, `$(…_dim)` to 4 or 5.
                out_v = @view out[ntuple(i -> i == $(out_dim) ? $(out_sel) : Colon(), 5)...]
                in1_v = @view in1[ntuple(i -> i == $(in1_dim) ? $(in1_sel) : Colon(), 5)...]
                in2_v = @view in2[ntuple(i -> i == $(in2_dim) ? $(in2_sel) : Colon(), 5)...]
                Threads.@spawn $(Symbol("$(front_name)_$(backend)!"))(
                    out_v, in1_v, in2_v, cdims2; kwargs...)
            end

            return out
        end
    end
end
331
+
293
332
# Declare the constructors of the dimension-descriptor types, for any positional
# arguments, as non-differentiable.
for dims_type in (:DenseConvDims, :DepthwiseConvDims, :PoolDims)
    @eval @non_differentiable $(dims_type)(::Any...)
end
0 commit comments