@@ -166,13 +166,16 @@ end
166
166
167
167
# First, we will define mappings from the generic API names to our accelerated backend
168
168
# implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using
169
- # im2col + GEMM. Do so in a loop, here:
169
+ # im2col + GEMM.
170
+ # But we always support a fallback, non-accelerated path, where we use the direct, but
171
+ # slow, implementations. These should not typically be used, hence the `@warn`.
170
172
171
173
# GEMM types accelerated with `im2col`: the second entry of every item in
# `gemm_datatype_mappings`, collected into one `Union`.
const G = Union{(mapping[2] for mapping in gemm_datatype_mappings)...}
173
175
174
176
for (front_name, backend, signature) in (
175
- # This maps from public, front-facing name, to internal backend name
177
+ # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
178
+ # (frontend, backend, (out array signature, in1 array signature, in2 array signature, cdims type, (where-clause type parameters)))
176
179
(:conv , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
177
180
(:conv , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
178
181
)
213
216
214
217
# im2col-accelerated function forwarding definition
215
218
for (front_name, backend, signature) in (
216
- # This maps from public, front-facing name, to internal backend name
219
+ # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
220
+ # (frontend, backend, (out array signature, in1 array signature, in2 array signature, cdims type, (where-clause type parameters)))
217
221
(:∇conv_data , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
218
222
(:∇conv_data , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
219
223
)
@@ -256,13 +260,13 @@ for (front_name, backend, signature) in (
256
260
end
257
261
258
262
for (front_name, backend, signature) in (
259
- # This maps from public, front-facing name, to internal backend name
263
+ # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
264
+ # (frontend, backend, (out array signature, in1 array signature, in2 array signature, cdims type, (where-clause type parameters)))
260
265
(:∇conv_filter , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
261
266
(:∇conv_filter , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
262
267
)
263
268
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
264
269
@eval begin
265
- # println($(Symbol(["$(i)" for i in "$(signature[5])"]...))...)
266
270
function $ (Symbol (" $(front_name) !" ))(
267
271
out:: AbstractArray{$(signature[1][1]), $(signature[1][2])} ,
268
272
in1:: AbstractArray{$(signature[2][1]), $(signature[1][2])} ,
@@ -298,98 +302,37 @@ for (front_name, backend, signature) in (
298
302
end
299
303
300
304
301
# Forwarding definitions for the depthwise convolution front-ends. Every entry maps a
# public, front-facing name to an internal backend name, together with the pieces of the
# method signature to generate:
#   (frontend, backend, ((out eltype, out ndims), (in1 eltype, in1 ndims),
#                        (in2 eltype, in2 ndims), cdims type, where-clause parameters))
# The `:im2col` entries cover 5-d arrays sharing one element type `T <: G`; the `:direct`
# entries are the slow fallback for any other combination of element types.
for (front_name, backend, signature) in (
    (:depthwiseconv, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
    (:depthwiseconv, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),

    (:∇depthwiseconv_data, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
    (:∇depthwiseconv_data, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),

    (:∇depthwiseconv_filter, :im2col, ((:T, 5), (:T, 5), (:T, 5), :C, (:(T <: G), :(C <: ConvDims)))),
    (:∇depthwiseconv_filter, :direct, ((:yT, :N), (:T1, :N), (:T2, :N), :C, (:yT, :T1, :T2, :N, :(C <: ConvDims)))),
)
    (out_sig, in1_sig, in2_sig, cdims_type, where_params) = signature
    frontend_fn = Symbol(front_name, "!")
    backend_fn = Symbol(front_name, "_", backend, "!")

    # Only the `:direct` signature binds `yT`, `T1` and `T2`, so the slow-path warning is
    # spliced in at code-generation time rather than being guarded by a runtime string
    # comparison inside every generated method.
    slow_path_warning = if backend === :direct
        quote
            # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
            if yT === Float64
                @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                             "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
            end
        end
    else
        nothing
    end

    # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
    @eval begin
        # Function forwarding definition: dispatch on the array element types / ranks and
        # hand everything over to the selected backend. Note the `;` before `kwargs...`:
        # they must be keyword arguments so that callers' keywords reach the backend.
        function $(frontend_fn)(
                out::AbstractArray{$(out_sig[1]), $(out_sig[2])},
                in1::AbstractArray{$(in1_sig[1]), $(in1_sig[2])},
                in2::AbstractArray{$(in2_sig[1]), $(in2_sig[2])},
                cdims::$(cdims_type);
                kwargs...) where {$(where_params...)}
            $(slow_path_warning)
            return $(backend_fn)(out, in1, in2, cdims; kwargs...)
        end
    end
end
336
335
337
- # # direct function forwarding definition
338
- # function ∇conv_data!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
339
- # in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
340
- # if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
341
- # @warn string("Slow fallback implementation invoked for ", string(front_name), "! ",
342
- # "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
343
- # end
344
-
345
- # dx_cs = Iterators.partition(1:size(out, 4),
346
- # channels_in(cdims) ÷ groupcount(cdims))
347
- # w_cs = Iterators.partition(1:size(in2, 5),
348
- # channels_out(cdims) ÷ groupcount(cdims))
349
- # dy_cs = Iterators.partition(1:size(in1, 4),
350
- # channels_out(cdims) ÷ groupcount(cdims))
351
- # cdims2 = basetype(C)(cdims,
352
- # G = 1,
353
- # C_in = channels_in(cdims) ÷ groupcount(cdims),
354
- # C_out = channels_out(cdims) ÷ groupcount(cdims))
355
-
356
- # Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
357
- # dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
358
- # dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
359
- # wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
360
- # Threads.@spawn ∇conv_data_direct!(dxv, dyv, wv, cdims2; kwargs...)
361
- # end
362
-
363
- # return out
364
- # end
365
-
366
- # function ∇conv_filter!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
367
- # in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
368
- # if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
369
- # @warn string("Slow fallback implementation invoked for ", string(front_name), "! ",
370
- # "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
371
- # end
372
- # dw_cs = Iterators.partition(1:size(out, 5),
373
- # channels_out(cdims) ÷ groupcount(cdims))
374
- # dy_cs = Iterators.partition(1:size(in2, 4),
375
- # channels_out(cdims) ÷ groupcount(cdims))
376
- # x_cs = Iterators.partition(1:size(in1, 4),
377
- # channels_in(cdims) ÷ groupcount(cdims))
378
- # cdims2 = basetype(C)(cdims,
379
- # G = 1,
380
- # C_in = channels_in(cdims) ÷ groupcount(cdims),
381
- # C_out = channels_out(cdims) ÷ groupcount(cdims))
382
-
383
- # Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
384
- # x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
385
- # dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
386
- # dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
387
- # Threads.@spawn ∇conv_filter_direct!(dw, x, dy, cdims2; kwargs...)
388
- # end
389
-
390
- # return out
391
- # end
392
-
393
336
# Apply `@non_differentiable` to calls of each dims type, whatever the arguments.
for dims_type in (:DenseConvDims, :DepthwiseConvDims, :PoolDims)
    @eval @non_differentiable $(dims_type)(::Any...)
end
0 commit comments