@@ -171,26 +171,33 @@ end
171
171
# These are the GEMM types we will accelerate with `im2col`
172
172
const G = Union{[x[2 ] for x in gemm_datatype_mappings]. .. }
173
173
174
- for (front_name, backend) in (
175
- # This maps from public, front-facing name, to internal backend name
176
- :conv => :im2col ,
177
- )
178
-
174
+ for (front_name, backend, signature ) in (
175
+ # This maps from public, front-facing name, to internal backend name
176
+ ( :conv , :im2col , (( :T , 5 ), ( :T , 5 ), ( :T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))) ,
177
+ ( :conv , :direct , (( :yT , :N ), ( :T1 , :N ), ( :T2 , :N ), :C , ( :yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
178
+ )
179
179
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
180
180
@eval begin
181
- # im2col-accelerated function forwarding definition
181
+
182
182
function $ (Symbol (" $(front_name) !" ))(
183
- out:: AbstractArray{T,5} , in1:: AbstractArray{T,5} ,
184
- in2:: AbstractArray{T,5} , cdims:: C ; kwargs... ) where {T <: $G , C <: ConvDims }
183
+ out:: AbstractArray{$(signature[1][1]), $(signature[1][2])} ,
184
+ in1:: AbstractArray{$(signature[2][1]), $(signature[1][2])} ,
185
+ in2:: AbstractArray{$(signature[3][1]), $(signature[1][2])} ,
186
+ cdims:: $ (signature[4 ]),
187
+ kwargs... ) where {$ (signature[5 ]. .. )}
188
+ if $ (string (backend)) == " direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
189
+ @warn string (" Slow fallback implementation invoked for " , $ (string (front_name)), " ! " ,
190
+ " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
191
+ end
185
192
186
193
x_cs = Iterators. partition (1 : size (in1, 4 ),
187
- channels_in (cdims) ÷ groupcount (cdims))
194
+ channels_in (cdims) ÷ groupcount (cdims))
188
195
w_cs = Iterators. partition (1 : size (in2, 5 ),
189
- channels_out (cdims) ÷ groupcount (cdims))
196
+ channels_out (cdims) ÷ groupcount (cdims))
190
197
cdims2 = basetype (C)(cdims,
191
- G = 1 ,
192
- C_in = channels_in (cdims) ÷ groupcount (cdims),
193
- C_out = channels_out (cdims) ÷ groupcount (cdims))
198
+ G = 1 ,
199
+ C_in = channels_in (cdims) ÷ groupcount (cdims),
200
+ C_out = channels_out (cdims) ÷ groupcount (cdims))
194
201
195
202
Threads. @sync for (xc, wc) in zip (x_cs, w_cs)
196
203
x = @view in1[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
@@ -205,51 +212,89 @@ for (front_name, backend) in (
205
212
end
206
213
207
214
# im2col-accelerated function forwarding definition
208
- function ∇conv_data! (out:: AbstractArray{T,5} , in1:: AbstractArray{T,5} ,
209
- in2:: AbstractArray{T,5} , cdims:: C ; kwargs... ) where {T <: G , C <: ConvDims }
210
-
211
- dx_cs = Iterators. partition (1 : size (out, 4 ),
212
- channels_in (cdims) ÷ groupcount (cdims))
213
- w_cs = Iterators. partition (1 : size (in2, 5 ),
214
- channels_out (cdims) ÷ groupcount (cdims))
215
- dy_cs = Iterators. partition (1 : size (in1, 4 ),
216
- channels_out (cdims) ÷ groupcount (cdims))
217
- cdims2 = basetype (C)(cdims,
218
- G = 1 ,
219
- C_in = channels_in (cdims) ÷ groupcount (cdims),
220
- C_out = channels_out (cdims) ÷ groupcount (cdims))
221
-
222
- Threads. @sync for (xc, yc, wc) in zip (dx_cs, dy_cs, w_cs)
223
- dxv = @view out[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
224
- dyv = @view in1[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
225
- wv = @view in2[ntuple (i -> i == 5 ? wc : Colon (), 5 )... ]
226
- Threads. @spawn ∇conv_data_im2col! (dxv, dyv, wv, cdims2; kwargs... )
227
- end
215
+ for (front_name, backend, signature) in (
216
+ # This maps from public, front-facing name, to internal backend name
217
+ (:∇conv_data , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
218
+ (:∇conv_data , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
219
+ )
220
+ # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
221
+ @eval begin
222
+ # println($(Symbol(["$(i)" for i in "$(signature[5])"]...))...)
223
+ function $ (Symbol (" $(front_name) !" ))(
224
+ out:: AbstractArray{$(signature[1][1]), $(signature[1][2])} ,
225
+ in1:: AbstractArray{$(signature[2][1]), $(signature[1][2])} ,
226
+ in2:: AbstractArray{$(signature[3][1]), $(signature[1][2])} ,
227
+ cdims:: $ (signature[4 ]),
228
+ kwargs... ) where {$ (signature[5 ]. .. )}
229
+ if $ (string (backend)) == " direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
230
+ @warn string (" Slow fallback implementation invoked for " , $ (string (front_name)), " ! " ,
231
+ " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
232
+ end
228
233
229
- return out
230
- end
231
234
232
- function ∇conv_filter! (out:: AbstractArray{T,5} , in1:: AbstractArray{T,5} ,
233
- in2:: AbstractArray{T,5} , cdims:: C ; kwargs... ) where {T <: G , C <: ConvDims }
234
- dw_cs = Iterators. partition (1 : size (out, 5 ),
235
- channels_out (cdims) ÷ groupcount (cdims))
236
- dy_cs = Iterators. partition (1 : size (in2, 4 ),
237
- channels_out (cdims) ÷ groupcount (cdims))
238
- x_cs = Iterators. partition (1 : size (in1, 4 ),
239
- channels_in (cdims) ÷ groupcount (cdims))
240
- cdims2 = basetype (C)(cdims,
241
- G = 1 ,
242
- C_in = channels_in (cdims) ÷ groupcount (cdims),
243
- C_out = channels_out (cdims) ÷ groupcount (cdims))
244
-
245
- Threads. @sync for (wc, xc, yc) in zip (dw_cs, x_cs, dy_cs)
246
- x = @view in1[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
247
- dy = @view in2[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
248
- dw = @view out[ntuple (i -> i == 5 ? yc : Colon (), 5 )... ]
249
- Threads. @spawn ∇conv_filter_im2col! (dw, x, dy, cdims2; kwargs... )
235
+ dx_cs = Iterators. partition (1 : size (out, 4 ),
236
+ channels_in (cdims) ÷ groupcount (cdims))
237
+ w_cs = Iterators. partition (1 : size (in2, 5 ),
238
+ channels_out (cdims) ÷ groupcount (cdims))
239
+ dy_cs = Iterators. partition (1 : size (in1, 4 ),
240
+ channels_out (cdims) ÷ groupcount (cdims))
241
+ cdims2 = basetype (C)(cdims,
242
+ G = 1 ,
243
+ C_in = channels_in (cdims) ÷ groupcount (cdims),
244
+ C_out = channels_out (cdims) ÷ groupcount (cdims))
245
+
246
+ Threads. @sync for (xc, yc, wc) in zip (dx_cs, dy_cs, w_cs)
247
+ dxv = @view out[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
248
+ dyv = @view in1[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
249
+ wv = @view in2[ntuple (i -> i == 5 ? wc : Colon (), 5 )... ]
250
+ Threads. @spawn $ (Symbol (" $(front_name) _$(backend) !" ))(dxv, dyv, wv, cdims2; kwargs... )
251
+ end
252
+
253
+ return out
254
+ end
250
255
end
256
+ end
251
257
252
- return out
258
+ for (front_name, backend, signature) in (
259
+ # This maps from public, front-facing name, to internal backend name
260
+ (:∇conv_filter , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
261
+ (:∇conv_filter , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
262
+ )
263
+ # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
264
+ @eval begin
265
+ # println($(Symbol(["$(i)" for i in "$(signature[5])"]...))...)
266
+ function $ (Symbol (" $(front_name) !" ))(
267
+ out:: AbstractArray{$(signature[1][1]), $(signature[1][2])} ,
268
+ in1:: AbstractArray{$(signature[2][1]), $(signature[1][2])} ,
269
+ in2:: AbstractArray{$(signature[3][1]), $(signature[1][2])} ,
270
+ cdims:: $ (signature[4 ]),
271
+ kwargs... ) where {$ (signature[5 ]. .. )}
272
+ if $ (string (backend)) == " direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
273
+ @warn string (" Slow fallback implementation invoked for " , $ (string (front_name)), " ! " ,
274
+ " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
275
+ end
276
+
277
+ dw_cs = Iterators. partition (1 : size (out, 5 ),
278
+ channels_out (cdims) ÷ groupcount (cdims))
279
+ dy_cs = Iterators. partition (1 : size (in2, 4 ),
280
+ channels_out (cdims) ÷ groupcount (cdims))
281
+ x_cs = Iterators. partition (1 : size (in1, 4 ),
282
+ channels_in (cdims) ÷ groupcount (cdims))
283
+ cdims2 = basetype (C)(cdims,
284
+ G = 1 ,
285
+ C_in = channels_in (cdims) ÷ groupcount (cdims),
286
+ C_out = channels_out (cdims) ÷ groupcount (cdims))
287
+
288
+ Threads. @sync for (wc, xc, yc) in zip (dw_cs, x_cs, dy_cs)
289
+ x = @view in1[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
290
+ dy = @view in2[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
291
+ dw = @view out[ntuple (i -> i == 5 ? yc : Colon (), 5 )... ]
292
+ Threads. @spawn $ (Symbol (" $(front_name) _$(backend) !" ))(dw, x, dy, cdims2; kwargs... )
293
+ end
294
+
295
+ return out
296
+ end
297
+ end
253
298
end
254
299
255
300
@@ -289,101 +334,61 @@ for front_name in (:depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_fil
289
334
end
290
335
end
291
336
292
- for (front_name, backend) in (
293
- # This maps from public, front-facing name, to internal backend name
294
- :conv => :direct ,
295
- # :∇conv_data => :direct,
296
- # :∇conv_filter => :direct,
297
- )
298
-
299
- # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
300
- @eval begin
301
- # im2col-accelerated function forwarding definition
302
- function $ (Symbol (" $(front_name) !" ))(
303
- out:: AbstractArray{yT,N} , in1:: AbstractArray{T1,N} ,
304
- in2:: AbstractArray{T2,N} , cdims:: C ;
305
- kwargs... ) where {yT, T1, T2, N, C <: ConvDims }
306
- if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
307
- @warn string (" Slow fallback implementation invoked for " , $ (string (front_name)), " ! " ,
308
- " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
309
- end
310
-
311
- x_cs = Iterators. partition (1 : size (in1, 4 ),
312
- channels_in (cdims) ÷ groupcount (cdims))
313
- w_cs = Iterators. partition (1 : size (in2, 5 ),
314
- channels_out (cdims) ÷ groupcount (cdims))
315
- cdims2 = basetype (C)(cdims,
316
- G = 1 ,
317
- C_in = channels_in (cdims) ÷ groupcount (cdims),
318
- C_out = channels_out (cdims) ÷ groupcount (cdims))
319
-
320
- Threads. @sync for (xc, wc) in zip (x_cs, w_cs)
321
- x = @view in1[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
322
- w = @view in2[ntuple (i -> i == 5 ? wc : Colon (), 5 )... ]
323
- y = @view out[ntuple (i -> i == 4 ? wc : Colon (), 5 )... ]
324
- Threads. @spawn $ (Symbol (" $(front_name) _$(backend) !" ))(y, x, w, cdims2; kwargs... )
325
- end
326
-
327
- return out
328
- end
329
- end
330
- end
331
-
332
- # direct function forwarding definition
333
- function ∇conv_data! (out:: AbstractArray{yT,N} , in1:: AbstractArray{T1,N} ,
334
- in2:: AbstractArray{T2,N} , cdims:: C ; kwargs... ) where {yT, T1, T2, N, C <: ConvDims }
335
- if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
336
- @warn string (" Slow fallback implementation invoked for " , string (front_name), " ! " ,
337
- " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
338
- end
337
+ # # direct function forwarding definition
338
+ # function ∇conv_data!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
339
+ # in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
340
+ # if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
341
+ # @warn string("Slow fallback implementation invoked for ", string(front_name), "! ",
342
+ # "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
343
+ # end
339
344
340
- dx_cs = Iterators. partition (1 : size (out, 4 ),
341
- channels_in (cdims) ÷ groupcount (cdims))
342
- w_cs = Iterators. partition (1 : size (in2, 5 ),
343
- channels_out (cdims) ÷ groupcount (cdims))
344
- dy_cs = Iterators. partition (1 : size (in1, 4 ),
345
- channels_out (cdims) ÷ groupcount (cdims))
346
- cdims2 = basetype (C)(cdims,
347
- G = 1 ,
348
- C_in = channels_in (cdims) ÷ groupcount (cdims),
349
- C_out = channels_out (cdims) ÷ groupcount (cdims))
350
-
351
- Threads. @sync for (xc, yc, wc) in zip (dx_cs, dy_cs, w_cs)
352
- dxv = @view out[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
353
- dyv = @view in1[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
354
- wv = @view in2[ntuple (i -> i == 5 ? wc : Colon (), 5 )... ]
355
- Threads. @spawn ∇conv_data_direct! (dxv, dyv, wv, cdims2; kwargs... )
356
- end
345
+ # dx_cs = Iterators.partition(1:size(out, 4),
346
+ # channels_in(cdims) ÷ groupcount(cdims))
347
+ # w_cs = Iterators.partition(1:size(in2, 5),
348
+ # channels_out(cdims) ÷ groupcount(cdims))
349
+ # dy_cs = Iterators.partition(1:size(in1, 4),
350
+ # channels_out(cdims) ÷ groupcount(cdims))
351
+ # cdims2 = basetype(C)(cdims,
352
+ # G = 1,
353
+ # C_in = channels_in(cdims) ÷ groupcount(cdims),
354
+ # C_out = channels_out(cdims) ÷ groupcount(cdims))
355
+
356
+ # Threads.@sync for (xc, yc, wc) in zip(dx_cs, dy_cs, w_cs)
357
+ # dxv = @view out[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
358
+ # dyv = @view in1[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
359
+ # wv = @view in2[ntuple(i -> i == 5 ? wc : Colon(), 5)...]
360
+ # Threads.@spawn ∇conv_data_direct!(dxv, dyv, wv, cdims2; kwargs...)
361
+ # end
357
362
358
- return out
359
- end
363
+ # return out
364
+ # end
360
365
361
- function ∇conv_filter! (out:: AbstractArray{yT,N} , in1:: AbstractArray{T1,N} ,
362
- in2:: AbstractArray{T2,N} , cdims:: C ; kwargs... ) where {yT, T1, T2, N, C <: ConvDims }
363
- if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
364
- @warn string (" Slow fallback implementation invoked for " , string (front_name), " ! " ,
365
- " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
366
- end
367
- dw_cs = Iterators. partition (1 : size (out, 5 ),
368
- channels_out (cdims) ÷ groupcount (cdims))
369
- dy_cs = Iterators. partition (1 : size (in2, 4 ),
370
- channels_out (cdims) ÷ groupcount (cdims))
371
- x_cs = Iterators. partition (1 : size (in1, 4 ),
372
- channels_in (cdims) ÷ groupcount (cdims))
373
- cdims2 = basetype (C)(cdims,
374
- G = 1 ,
375
- C_in = channels_in (cdims) ÷ groupcount (cdims),
376
- C_out = channels_out (cdims) ÷ groupcount (cdims))
377
-
378
- Threads. @sync for (wc, xc, yc) in zip (dw_cs, x_cs, dy_cs)
379
- x = @view in1[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
380
- dy = @view in2[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
381
- dw = @view out[ntuple (i -> i == 5 ? yc : Colon (), 5 )... ]
382
- Threads. @spawn ∇conv_filter_direct! (dw, x, dy, cdims2; kwargs... )
383
- end
366
+ # function ∇conv_filter!(out::AbstractArray{yT,N}, in1::AbstractArray{T1,N},
367
+ # in2::AbstractArray{T2,N}, cdims::C; kwargs...) where {yT, T1, T2, N, C <: ConvDims}
368
+ # if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
369
+ # @warn string("Slow fallback implementation invoked for ", string(front_name), "! ",
370
+ # "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
371
+ # end
372
+ # dw_cs = Iterators.partition(1:size(out, 5),
373
+ # channels_out(cdims) ÷ groupcount(cdims))
374
+ # dy_cs = Iterators.partition(1:size(in2, 4),
375
+ # channels_out(cdims) ÷ groupcount(cdims))
376
+ # x_cs = Iterators.partition(1:size(in1, 4),
377
+ # channels_in(cdims) ÷ groupcount(cdims))
378
+ # cdims2 = basetype(C)(cdims,
379
+ # G = 1,
380
+ # C_in = channels_in(cdims) ÷ groupcount(cdims),
381
+ # C_out = channels_out(cdims) ÷ groupcount(cdims))
382
+
383
+ # Threads.@sync for (wc, xc, yc) in zip(dw_cs, x_cs, dy_cs)
384
+ # x = @view in1[ntuple(i -> i == 4 ? xc : Colon(), 5)...]
385
+ # dy = @view in2[ntuple(i -> i == 4 ? yc : Colon(), 5)...]
386
+ # dw = @view out[ntuple(i -> i == 5 ? yc : Colon(), 5)...]
387
+ # Threads.@spawn ∇conv_filter_direct!(dw, x, dy, cdims2; kwargs...)
388
+ # end
384
389
385
- return out
386
- end
390
+ # return out
391
+ # end
387
392
388
393
for Dims in [:DenseConvDims , :DepthwiseConvDims , :PoolDims ]
389
394
@eval @non_differentiable $ Dims (:: Any... )
0 commit comments