@@ -22,7 +22,7 @@ by setting `alpha` to a nonunitary value, various gain factors can be applied.
22
22
Note for the particularly performance-minded, you can provide a pre-allocated `col`,
23
23
which should eliminate any need for large allocations within this method.
24
24
"""
25
- @timeit_debug to function conv_im2col! (
25
+ function conv_im2col! (
26
26
y:: AbstractArray{T,5} , x:: AbstractArray{T,5} ,
27
27
w:: AbstractArray{T,5} , cdims:: DenseConvDims ;
28
28
col:: AbstractArray{T,2} = similar (x, im2col_dims (cdims)),
@@ -49,12 +49,12 @@ which should eliminate any need for large allocations within this method.
49
49
@inbounds for batch_idx in 1 : size (x,5 )
50
50
# We invoke `@timeit_debug` on the outside of `im2col!()` because inference
51
51
# doesn't like us putting it on the inside.
52
- @timeit_debug to " im2col! " im2col! (col, view (x, :, :, :, :, batch_idx), cdims)
52
+ im2col! (col, view (x, :, :, :, :, batch_idx), cdims)
53
53
GC. @preserve col, w, y, begin
54
54
col_ptr = pointer (col)
55
55
w_ptr = pointer (w)
56
56
y_ptr = pointer (y, (batch_idx - 1 )* M* N + 1 )
57
- @timeit_debug to " gemm! " gemm! (Val (false ), Val (false ), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
57
+ gemm! (Val (false ), Val (false ), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
58
58
end
59
59
end
60
60
return y
66
66
Conv backward pass onto the weights using im2col and GEMM; stores the result in `dw`.
67
67
See the documentation for `conv_im2col!()` for explanation of optional parameters.
68
68
"""
69
- @timeit_debug to function ∇conv_filter_im2col! (
69
+ function ∇conv_filter_im2col! (
70
70
dw:: AbstractArray{T,5} , x:: AbstractArray{T,5} ,
71
71
dy:: AbstractArray{T,5} , cdims:: DenseConvDims ;
72
72
col:: AbstractArray{T,2} = similar (dw, im2col_dims (cdims)),
@@ -95,14 +95,12 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter
95
95
K = prod (output_size (cdims))
96
96
97
97
@inbounds for batch_idx in 1 : size (x,5 )
98
- # We invoke `@timeit_debug` on the outside of `im2col!()` because inference
99
- # doesn't like us putting it on the inside.
100
- @timeit_debug to " im2col!" im2col! (col, view (x, :, :, :, :, batch_idx), cdims)
98
+ im2col! (col, view (x, :, :, :, :, batch_idx), cdims)
101
99
GC. @preserve col, dw, dy, begin
102
100
col_ptr = pointer (col)
103
101
dy_ptr = pointer (dy,(batch_idx - 1 )* K* N + 1 )
104
102
dw_ptr = pointer (dw)
105
- @timeit_debug to " gemm! " gemm! (Val (true ), Val (false ), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
103
+ gemm! (Val (true ), Val (false ), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
106
104
end
107
105
108
106
# Because we accumulate over batches in this loop, we must set `beta` equal
118
116
Conv2d backward pass onto the input using im2col and GEMM; stores the result in `dx`.
119
117
See the documentation for `conv_im2col!()` for explanation of other parameters.
120
118
"""
121
- @timeit_debug to function ∇conv_data_im2col! (
119
+ function ∇conv_data_im2col! (
122
120
dx:: AbstractArray{T,5} , dy:: AbstractArray{T,5} ,
123
121
w:: AbstractArray{T,5} , cdims:: DenseConvDims ;
124
122
col:: AbstractArray{T,2} = similar (dx, im2col_dims (cdims)),
@@ -149,9 +147,9 @@ See the documentation for `conv_im2col!()` for explanation of other parameters.
149
147
dy_ptr = pointer (dy, (batch_idx - 1 )* M* K + 1 )
150
148
w_ptr = pointer (w)
151
149
col_ptr = pointer (col)
152
- @timeit_debug to " gemm! " gemm! (Val (false ), Val (true ), M, N, K, alpha, dy_ptr, w_ptr, T (0 ), col_ptr)
150
+ gemm! (Val (false ), Val (true ), M, N, K, alpha, dy_ptr, w_ptr, T (0 ), col_ptr)
153
151
end
154
- @timeit_debug to " col2im! " col2im! (view (dx, :, :, :, :, batch_idx), col, cdims)
152
+ col2im! (view (dx, :, :, :, :, batch_idx), col, cdims)
155
153
end
156
154
return dx
157
155
end
@@ -207,77 +205,74 @@ function im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4},
207
205
# We begin by copying the central region of the image which requires no padding at all.
208
206
# Eliminating the branches of the fully generalized version below gives us a nice
209
207
# speedup on the majority of the data.
210
- @timeit_debug to " im2col!() - central region" begin
211
- @inbounds for c in 1 : C_in
212
- # Unpack "central region"
213
- w_region, h_region, d_region = central_region
214
-
215
- for kd in 1 : kernel_d,
216
- kh in 1 : kernel_h,
217
- kw in 1 : kernel_w,
218
- d in d_region,
219
- h in h_region,
220
- w in w_region
221
-
222
- input_kd = project (d, stride_d, pad_d_lo) + (kd - 1 )* dil_d
223
- input_kh = project (h, stride_h, pad_h_lo) + (kh - 1 )* dil_h
224
- input_kw = project (w, stride_w, pad_w_lo) + (kw - 1 )* dil_w
225
- kidxs = kernel_index (kw, kh, kd, cdims)
208
+ @inbounds for c in 1 : C_in
209
+ # Unpack "central region"
210
+ w_region, h_region, d_region = central_region
226
211
227
- xval:: T = x[input_kw, input_kh, input_kd, c]
228
- col_reshaped[w, h, d, kidxs... , c] = xval
229
- end
212
+ for kd in 1 : kernel_d,
213
+ kh in 1 : kernel_h,
214
+ kw in 1 : kernel_w,
215
+ d in d_region,
216
+ h in h_region,
217
+ w in w_region
218
+
219
+ input_kd = project (d, stride_d, pad_d_lo) + (kd - 1 )* dil_d
220
+ input_kh = project (h, stride_h, pad_h_lo) + (kh - 1 )* dil_h
221
+ input_kw = project (w, stride_w, pad_w_lo) + (kw - 1 )* dil_w
222
+ kidxs = kernel_index (kw, kh, kd, cdims)
223
+
224
+ xval:: T = x[input_kw, input_kh, input_kd, c]
225
+ col_reshaped[w, h, d, kidxs... , c] = xval
230
226
end
231
227
end
232
228
229
+
233
230
# For each "padded region", we run the fully general version
234
- @timeit_debug to " im2col!() - padded region" begin
235
- @inbounds for (w_region, h_region, d_region) in padded_regions
236
- for c in 1 : C_in,
237
- d in d_region,
238
- h in h_region,
239
- w in w_region,
240
- kd in 1 : kernel_d,
241
- kh in 1 : kernel_h,
242
- kw in 1 : kernel_w
231
+ @inbounds for (w_region, h_region, d_region) in padded_regions
232
+ for c in 1 : C_in,
233
+ d in d_region,
234
+ h in h_region,
235
+ w in w_region,
236
+ kd in 1 : kernel_d,
237
+ kh in 1 : kernel_h,
238
+ kw in 1 : kernel_w
243
239
244
- input_kd = project (d, stride_d, pad_d_lo) + (kd - 1 )* dil_d
245
- input_kh = project (h, stride_h, pad_h_lo) + (kh - 1 )* dil_h
246
- input_kw = project (w, stride_w, pad_w_lo) + (kw - 1 )* dil_w
240
+ input_kd = project (d, stride_d, pad_d_lo) + (kd - 1 )* dil_d
241
+ input_kh = project (h, stride_h, pad_h_lo) + (kh - 1 )* dil_h
242
+ input_kw = project (w, stride_w, pad_w_lo) + (kw - 1 )* dil_w
247
243
248
- kidxs = kernel_index (kw, kh, kd, cdims)
244
+ kidxs = kernel_index (kw, kh, kd, cdims)
249
245
250
- # If this d is off the edge, then deal with the entire plane
251
- # in one fell swoop, like a ravenous flock of crows. CAW CAW.
252
- if input_kd <= 0 || input_kd > depth
253
- for kh in 1 : kernel_h,
254
- kw in 1 : kernel_w
255
- col_reshaped[w, h, d, kidxs... , c] = T (0 )
256
- end
257
- continue
258
- end
259
-
260
- # Same for `h`, but in this case it's only a line, not a plane.
261
- # This results in slightly less caw'ing.
262
- if input_kh <= 0 || input_kh > height
263
- for kw in 1 : kernel_w
264
- col_reshaped[w, h, d, kidxs... , c] = T (0 )
265
- end
266
- continue
246
+ # If this d is off the edge, then deal with the entire plane
247
+ # in one fell swoop, like a ravenous flock of crows. CAW CAW.
248
+ if input_kd <= 0 || input_kd > depth
249
+ for kh in 1 : kernel_h,
250
+ kw in 1 : kernel_w
251
+ col_reshaped[w, h, d, kidxs... , c] = T (0 )
267
252
end
253
+ continue
254
+ end
268
255
269
- # If this `w` is off the edge it and only it gets cleared out
270
- if input_kw <= 0 || input_kw > width
256
+ # Same for `h`, but in this case it's only a line, not a plane.
257
+ # This results in slightly less caw'ing.
258
+ if input_kh <= 0 || input_kh > height
259
+ for kw in 1 : kernel_w
271
260
col_reshaped[w, h, d, kidxs... , c] = T (0 )
272
- continue
273
261
end
262
+ continue
263
+ end
274
264
275
- # Copy the data over
276
- xval:: T = x[input_kw, input_kh, input_kd, c]
277
- col_reshaped[w, h, d, kidxs... , c] = xval
265
+ # If this `w` is off the edge it and only it gets cleared out
266
+ if input_kw <= 0 || input_kw > width
267
+ col_reshaped[w, h, d, kidxs... , c] = T (0 )
268
+ continue
278
269
end
270
+
271
+ # Copy the data over
272
+ xval:: T = x[input_kw, input_kh, input_kd, c]
273
+ col_reshaped[w, h, d, kidxs... , c] = xval
279
274
end
280
- end
275
+ end
281
276
end
282
277
283
278
0 commit comments