@@ -166,31 +166,41 @@ end
166
166
167
167
# First, we will define mappings from the generic API names to our accelerated backend
168
168
# implementations. For homogeneous-datatype 1, 2 and 3d convolutions, we default to using
169
- # im2col + GEMM. Do so in a loop, here:
169
+ # im2col + GEMM.
170
+ # But we always support a fallback, non-accelerated path, where we use the direct, but
171
+ # slow, implementations. These should not typically be used, hence the `@warn` in the `direct` branch below.
170
172
171
173
# These are the GEMM types we will accelerate with `im2col`
172
174
const G = Union{[x[2 ] for x in gemm_datatype_mappings]. .. }
173
175
174
- for (front_name, backend) in (
175
- # This maps from public, front-facing name, to internal backend name
176
- :conv => :im2col ,
177
- )
178
-
176
+ for (front_name, backend, signature) in (
177
+ # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
178
+ # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, cdims type, (where-clause type parameters)))
179
+ (:conv , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
180
+ (:conv , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
181
+ )
179
182
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
180
183
@eval begin
181
- # im2col-accelerated function forwarding definition
184
+
182
185
function $ (Symbol (" $(front_name) !" ))(
183
- out:: AbstractArray{T,5} , in1:: AbstractArray{T,5} ,
184
- in2:: AbstractArray{T,5} , cdims:: C ; kwargs... ) where {T <: $G , C <: ConvDims }
186
+ out:: AbstractArray{$(signature[1][1]), $(signature[1][2])} ,
187
+ in1:: AbstractArray{$(signature[2][1]), $(signature[1][2])} ,
188
+ in2:: AbstractArray{$(signature[3][1]), $(signature[1][2])} ,
189
+ cdims:: $ (signature[4 ]);
190
+ kwargs... ) where {$ (signature[5 ]. .. )}
191
+ if $ (string (backend)) == " direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
192
+ @warn string (" Slow fallback implementation invoked for " , $ (string (front_name)), " ! " ,
193
+ " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
194
+ end
185
195
186
196
x_cs = Iterators. partition (1 : size (in1, 4 ),
187
- channels_in (cdims) ÷ groupcount (cdims))
197
+ channels_in (cdims) ÷ groupcount (cdims))
188
198
w_cs = Iterators. partition (1 : size (in2, 5 ),
189
- channels_out (cdims) ÷ groupcount (cdims))
199
+ channels_out (cdims) ÷ groupcount (cdims))
190
200
cdims2 = basetype (C)(cdims,
191
- G = 1 ,
192
- C_in = channels_in (cdims) ÷ groupcount (cdims),
193
- C_out = channels_out (cdims) ÷ groupcount (cdims))
201
+ G = 1 ,
202
+ C_in = channels_in (cdims) ÷ groupcount (cdims),
203
+ C_out = channels_out (cdims) ÷ groupcount (cdims))
194
204
195
205
Threads. @sync for (xc, wc) in zip (x_cs, w_cs)
196
206
x = @view in1[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
@@ -205,87 +215,119 @@ for (front_name, backend) in (
205
215
end
206
216
207
217
# Forwarding definitions: one method per (front_name, backend) pair, im2col-accelerated or direct
208
- function ∇conv_data! (out:: AbstractArray{T,5} , in1:: AbstractArray{T,5} ,
209
- in2:: AbstractArray{T,5} , cdims:: C ; kwargs... ) where {T <: G , C <: ConvDims }
210
-
211
- dx_cs = Iterators. partition (1 : size (out, 4 ),
212
- channels_in (cdims) ÷ groupcount (cdims))
213
- w_cs = Iterators. partition (1 : size (in2, 5 ),
214
- channels_out (cdims) ÷ groupcount (cdims))
215
- dy_cs = Iterators. partition (1 : size (in1, 4 ),
216
- channels_out (cdims) ÷ groupcount (cdims))
217
- cdims2 = basetype (C)(cdims,
218
- G = 1 ,
219
- C_in = channels_in (cdims) ÷ groupcount (cdims),
220
- C_out = channels_out (cdims) ÷ groupcount (cdims))
221
-
222
- Threads. @sync for (xc, yc, wc) in zip (dx_cs, dy_cs, w_cs)
223
- dxv = @view out[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
224
- dyv = @view in1[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
225
- wv = @view in2[ntuple (i -> i == 5 ? wc : Colon (), 5 )... ]
226
- Threads. @spawn ∇conv_data_im2col! (dxv, dyv, wv, cdims2; kwargs... )
227
- end
218
+ for (front_name, backend, signature) in (
219
+ # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
220
+ # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, cdims type, (where-clause type parameters)))
221
+ (:∇conv_data , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
222
+ (:∇conv_data , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
223
+ )
224
+ # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
225
+ @eval begin
226
+ function $ (Symbol (" $(front_name) !" ))(
227
+ out:: AbstractArray{$(signature[1][1]), $(signature[1][2])} ,
228
+ in1:: AbstractArray{$(signature[2][1]), $(signature[1][2])} ,
229
+ in2:: AbstractArray{$(signature[3][1]), $(signature[1][2])} ,
230
+ cdims:: $ (signature[4 ]);
231
+ kwargs... ) where {$ (signature[5 ]. .. )}
232
+ if $ (string (backend)) == " direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
233
+ @warn string (" Slow fallback implementation invoked for " , $ (string (front_name)), " ! " ,
234
+ " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
235
+ end
228
236
229
- return out
230
- end
231
237
232
- function ∇conv_filter! (out:: AbstractArray{T,5} , in1:: AbstractArray{T,5} ,
233
- in2:: AbstractArray{T,5} , cdims:: C ; kwargs... ) where {T <: G , C <: ConvDims }
234
- dw_cs = Iterators. partition (1 : size (out, 5 ),
235
- channels_out (cdims) ÷ groupcount (cdims))
236
- dy_cs = Iterators. partition (1 : size (in2, 4 ),
237
- channels_out (cdims) ÷ groupcount (cdims))
238
- x_cs = Iterators. partition (1 : size (in1, 4 ),
239
- channels_in (cdims) ÷ groupcount (cdims))
240
- cdims2 = basetype (C)(cdims,
241
- G = 1 ,
242
- C_in = channels_in (cdims) ÷ groupcount (cdims),
243
- C_out = channels_out (cdims) ÷ groupcount (cdims))
244
-
245
- Threads. @sync for (wc, xc, yc) in zip (dw_cs, x_cs, dy_cs)
246
- x = @view in1[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
247
- dy = @view in2[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
248
- dw = @view out[ntuple (i -> i == 5 ? yc : Colon (), 5 )... ]
249
- Threads. @spawn ∇conv_filter_im2col! (dw, x, dy, cdims2; kwargs... )
250
- end
238
+ dx_cs = Iterators. partition (1 : size (out, 4 ),
239
+ channels_in (cdims) ÷ groupcount (cdims))
240
+ w_cs = Iterators. partition (1 : size (in2, 5 ),
241
+ channels_out (cdims) ÷ groupcount (cdims))
242
+ dy_cs = Iterators. partition (1 : size (in1, 4 ),
243
+ channels_out (cdims) ÷ groupcount (cdims))
244
+ cdims2 = basetype (C)(cdims,
245
+ G = 1 ,
246
+ C_in = channels_in (cdims) ÷ groupcount (cdims),
247
+ C_out = channels_out (cdims) ÷ groupcount (cdims))
248
+
249
+ Threads. @sync for (xc, yc, wc) in zip (dx_cs, dy_cs, w_cs)
250
+ dxv = @view out[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
251
+ dyv = @view in1[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
252
+ wv = @view in2[ntuple (i -> i == 5 ? wc : Colon (), 5 )... ]
253
+ Threads. @spawn $ (Symbol (" $(front_name) _$(backend) !" ))(dxv, dyv, wv, cdims2; kwargs... )
254
+ end
251
255
252
- return out
256
+ return out
257
+ end
258
+ end
253
259
end
254
260
255
-
256
- for (front_name, backend) in (
257
- # This maps from public, front-facing name, to internal backend name
258
- :depthwiseconv => :im2col ,
259
- :∇depthwiseconv_data => :im2col ,
260
- :∇depthwiseconv_filter => :im2col ,
261
- )
262
-
261
+ for (front_name, backend, signature) in (
262
+ # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
263
+ # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, cdims type, (where-clause type parameters)))
264
+ (:∇conv_filter , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
265
+ (:∇conv_filter , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
266
+ )
263
267
# We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
264
268
@eval begin
265
- # im2col-accelerated function forwarding definition
266
269
function $ (Symbol (" $(front_name) !" ))(
267
- out:: AbstractArray{T,5} , in1:: AbstractArray{T,5} ,
268
- in2:: AbstractArray{T,5} , cdims:: C ; kwargs... ) where {T <: $G , C <: ConvDims }
269
- $ (Symbol (" $(front_name) _$(backend) !" ))(out, in1, in2, cdims; kwargs... )
270
+ out:: AbstractArray{$(signature[1][1]), $(signature[1][2])} ,
271
+ in1:: AbstractArray{$(signature[2][1]), $(signature[1][2])} ,
272
+ in2:: AbstractArray{$(signature[3][1]), $(signature[1][2])} ,
273
+ cdims:: $ (signature[4 ]);
274
+ kwargs... ) where {$ (signature[5 ]. .. )}
275
+ if $ (string (backend)) == " direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
276
+ @warn string (" Slow fallback implementation invoked for " , $ (string (front_name)), " ! " ,
277
+ " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
278
+ end
279
+
280
+ dw_cs = Iterators. partition (1 : size (out, 5 ),
281
+ channels_out (cdims) ÷ groupcount (cdims))
282
+ dy_cs = Iterators. partition (1 : size (in2, 4 ),
283
+ channels_out (cdims) ÷ groupcount (cdims))
284
+ x_cs = Iterators. partition (1 : size (in1, 4 ),
285
+ channels_in (cdims) ÷ groupcount (cdims))
286
+ cdims2 = basetype (C)(cdims,
287
+ G = 1 ,
288
+ C_in = channels_in (cdims) ÷ groupcount (cdims),
289
+ C_out = channels_out (cdims) ÷ groupcount (cdims))
290
+
291
+ Threads. @sync for (wc, xc, yc) in zip (dw_cs, x_cs, dy_cs)
292
+ x = @view in1[ntuple (i -> i == 4 ? xc : Colon (), 5 )... ]
293
+ dy = @view in2[ntuple (i -> i == 4 ? yc : Colon (), 5 )... ]
294
+ dw = @view out[ntuple (i -> i == 5 ? yc : Colon (), 5 )... ]
295
+ Threads. @spawn $ (Symbol (" $(front_name) _$(backend) !" ))(dw, x, dy, cdims2; kwargs... )
296
+ end
297
+
298
+ return out
270
299
end
271
300
end
272
301
end
273
302
274
- # We always support a fallback, non-accelerated path, where we use the direct, but
275
- # slow, implementations. These should not typically be used, hence the `@warn`,
276
- # but let's go ahead and define them first:
277
- for front_name in (:conv , :∇conv_data , :∇conv_filter ,
278
- :depthwiseconv , :∇depthwiseconv_data , :∇depthwiseconv_filter )
303
+
304
+ for (front_name, backend, signature) in (
305
+ # This maps from public, front-facing name, to internal backend name, given the function signature and the where clause
306
+ # (frontend, backend, (out Array signature, in1 Array signature, in2 Array signature, cdims type, (where-clause type parameters)))
307
+ (:depthwiseconv , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
308
+ (:depthwiseconv , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
309
+
310
+ (:∇depthwiseconv_data , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
311
+ (:∇depthwiseconv_data , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
312
+
313
+ (:∇depthwiseconv_filter , :im2col , ((:T , 5 ), (:T , 5 ), (:T , 5 ), :C , (:(T <: G ), :(C <: ConvDims )))),
314
+ (:∇depthwiseconv_filter , :direct , ((:yT , :N ), (:T1 , :N ), (:T2 , :N ), :C , (:yT , :T1 , :T2 , :N , :(C <: ConvDims )))),
315
+ )
316
+
317
+ # We only define 3d conv primitives, we reshape lower down to get 1d and 2d convolution
279
318
@eval begin
319
+ # Forwarding definition: dispatches to the `$(front_name)_$(backend)!` implementation (im2col or direct)
280
320
function $ (Symbol (" $(front_name) !" ))(
281
- y:: AbstractArray{yT,N} , in1:: AbstractArray{T1,N} ,
282
- in2:: AbstractArray{T2,N} , cdims:: ConvDims ;
283
- kwargs... ) where {yT, T1, T2, N}
284
- if yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
321
+ out:: AbstractArray{$(signature[1][1]), $(signature[1][2])} ,
322
+ in1:: AbstractArray{$(signature[2][1]), $(signature[1][2])} ,
323
+ in2:: AbstractArray{$(signature[3][1]), $(signature[1][2])} ,
324
+ cdims:: $ (signature[4 ]);
325
+ kwargs... ) where {$ (signature[5 ]. .. )}
326
+ if $ (string (backend)) == " direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
285
327
@warn string (" Slow fallback implementation invoked for " , $ (string (front_name)), " ! " ,
286
- " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
328
+ " You probably don't want this; check your datatypes." ) yT T1 T2 maxlog= 1
287
329
end
288
- $ (Symbol (" $(front_name) _direct !" ))(y , in1, in2, cdims; kwargs... )
330
+ $ (Symbol (" $(front_name) _ $(backend) !" ))(out , in1, in2, cdims; kwargs... )
289
331
end
290
332
end
291
333
end
0 commit comments