@@ -10,32 +10,112 @@ function psize(p, x)
10
10
end
11
11
end
12
12
13
- function im2col_2d! (img:: AbstractArray{T,3} , col:: AbstractArray{T,2} , width:: Int , height:: Int , channels:: Int ,
14
- kernel_w:: Int , kernel_h:: Int , pad_w:: Int , pad_h:: Int , stride_w:: Int , stride_h:: Int ,
15
- dil_w:: Int , dil_h:: Int , mode:: Int ) where T
16
-
17
- height_col = div (height + 2 pad_h - (kernel_h - 1 ) * dil_h - 1 , stride_h) + 1
18
- width_col = div (width + 2 pad_w - (kernel_w - 1 ) * dil_w - 1 , stride_w) + 1
19
- channels_col = channels * kernel_h * kernel_w
13
"""
    ConvDims{img, kernel, channels, stride, padding, dilation, flipkernel}

Type-level description of a 2-D convolution. Every parameter is a value lifted into the
type domain so that kernels such as `im2col_2d!()` can be specialized on it:

- `img`:        `(width, height)` of the input image
- `kernel`:     `(width, height)` of the kernel
- `channels`:   number of input channels
- `stride`:     `(stride_w, stride_h)`
- `padding`:    `(pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi)`
- `dilation`:   `(dil_w, dil_h)`
- `flipkernel`: `Bool`, whether the kernel indices are mirrored while unpacking patches
"""
struct ConvDims{img, kernel, channels, stride, padding, dilation, flipkernel} end

# Accessors: each one simply returns the corresponding type parameter.
img_size(::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = I
kernel_size(::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = K
img_channels(::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = C
stride(::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = S
padding(::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = P
dilation(::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = D
flipkernel(::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = F

"""
    output_size(c::ConvDims)

Return the `(width, height)` of the output of the convolution described by `c`.

Each axis is computed independently from that axis' own image extent, low/high padding,
kernel extent, dilation and stride.
"""
function output_size(::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F}
    # Width axis: first entries of kernel/dilation/stride, padding entries 1 and 2.
    O_w = div(I[1] + P[1] + P[2] - (K[1] - 1) * D[1] - 1, S[1]) + 1
    # Height axis: second entries of kernel/dilation/stride, padding entries 3 and 4.
    O_h = div(I[2] + P[3] + P[4] - (K[2] - 1) * D[2] - 1, S[2]) + 1
    return (O_w, O_h)
end
30
+
31
# Unpack the patches of the 3-D image `img` into the 2-D matrix `col`, which is viewed as a
# (out_w, out_h, kernel_w, kernel_h, channels) array. All geometry (image/kernel size,
# stride, padding, dilation, kernel flipping) is read from the type parameters of `cdims`.
# Positions whose source pixel falls inside the padding are written as zero.
function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, cdims::ConvDims) where T
    width, height = img_size(cdims)
    kernel_w, kernel_h = kernel_size(cdims)
    channels = img_channels(cdims)
    # The `_hi` paddings are not used directly below; they only enter via output_size().
    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi = padding(cdims)
    dil_w, dil_h = dilation(cdims)
    stride_w, stride_h = stride(cdims)
    out_w, out_h = output_size(cdims)

    # Maps a kernel position to the position it is stored at in `col` (mirrored if flipped).
    flipk = flipkernel(cdims) ?
        ((kw, kh) -> (kernel_w - kw + 1, kernel_h - kh + 1)) :
        ((kw, kh) -> (kw, kh))

    # Five-dimensional view onto `col` so each axis can be indexed on its own.
    patches = reshape(col, (out_w, out_h, kernel_w, kernel_h, channels))
    blank = zero(T)

    # Number of leading/trailing output rows and columns whose patches may touch padding.
    # The `_hi` counts subtract the part of the padding that is never reached because of
    # the stride/dilation/kernel size combination.
    spill_w_lo = ceil(Int, pad_w_lo/stride_w)
    spill_h_lo = ceil(Int, pad_h_lo/stride_h)
    spill_w_hi = out_w - div(width + pad_w_lo - (kernel_w - 1)*dil_w, stride_w)
    spill_h_hi = out_h - div(height + pad_h_lo - (kernel_h - 1)*dil_h, stride_h)
    # First output index belonging to the trailing spill zone on each axis.
    first_hi_w = out_w - spill_w_hi + 1
    first_hi_h = out_h - spill_h_hi + 1

    # Project an output index onto the input coordinate of the patch's first element.
    project(idx, step, pad) = (idx - 1)*step - pad + 1

    # Interior of the output: no bounds tests are performed here, which keeps the hot
    # loop free of branches.
    for c in 1:channels, kh in 1:kernel_h, kw in 1:kernel_w
        for h in (spill_h_lo + 1):(out_h - spill_h_hi)
            src_h = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h

            @inbounds for w in (spill_w_lo + 1):(out_w - spill_w_hi)
                src_w = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
                patches[w, h, flipk(kw, kh)..., c] = img[src_w, src_h, c]
            end
        end
    end

    # Border strips (top, left, right, bottom) that need padding-aware handling.
    middle_rows = (spill_h_lo + 1):(first_hi_h - 1)
    border_regions = (
        (1:out_w, 1:spill_h_lo),
        (1:spill_w_lo, middle_rows),
        (first_hi_w:out_w, middle_rows),
        (1:out_w, first_hi_h:out_h),
    )

    # Fully general version, run on each border strip.
    for (w_region, h_region) in border_regions
        for c in 1:channels, kh in 1:kernel_h, kw in 1:kernel_w
            @inbounds for h in h_region
                src_h = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h

                if src_h <= 0 || src_h > height
                    # The whole row lies outside the image: zero it in one go.
                    for w in w_region
                        patches[w, h, flipk(kw, kh)..., c] = blank
                    end
                else
                    for w in w_region
                        src_w = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w

                        if src_w <= 0 || src_w > width
                            # Single pixel outside the image: clear it.
                            patches[w, h, flipk(kw, kh)..., c] = blank
                        else
                            patches[w, h, flipk(kw, kh)..., c] = img[src_w, src_h, c]
                        end
                    end
                end
            end
        end
    end
end
@@ -256,26 +336,41 @@ function depthwiseconv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4},
256
336
return dx
257
337
end
258
338
339
# Convolve the batch `x` with the kernels `w` into `y`, using the geometry encoded in
# `cdims`. Each image of the batch is unpacked with im2col_2d!() into a scratch matrix and
# then multiplied with the kernel matrix by gemm!(), scaled by `alpha`. Returns `y`.
#
# NOTE(review): assumes `y` is laid out as (Wy, Hy, Cy, Ny), matching the
# `Wx,Hx,Cx,Nx = size(x)` convention used by the keyword method of conv2d! — confirm that
# callers allocate `y` with the spatial size given by output_size(cdims).
function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4},
                 cdims::ConvDims; alpha=T(1)) where T
    Ww, Hw = kernel_size(cdims)
    Wy, Hy = output_size(cdims)
    Cx = img_channels(cdims)
    # Number of output channels. This must be the third dimension of `y`; the fourth is
    # the batch size and is only equal to it by coincidence.
    Cy = size(y, 3)

    # GEMM shape: (M x K) patch matrix times (K x N) kernel matrix; Y is the number of
    # elements one output image occupies in `y`.
    M, N, K, Y = Wy*Hy, Cy, Ww*Hw*Cx, Wy*Hy*Cy

    # Scratch buffer reused for every image in the batch.
    x2 = similar(x, im2col_dims(w, y))
    @inbounds for n in 1:size(x, 4)
        im2col_2d!(view(x, :, :, :, n), x2, cdims)
        # The result for image `n` is written starting at linear offset (n - 1)*Y + 1.
        gemm!('N', 'N', M, N, K, alpha, pointer(x2), pointer(w), T(0),
              pointer(y, (n - 1)*Y + 1))
    end
    return y
end
354
+
259
355
# Keyword front-end of conv2d!: validates the arguments, expands `padding`, `stride` and
# `dilation` with psize(), packs everything into a ConvDims and forwards to the
# ConvDims-based method. `mode == 0` selects the flipped kernel.
function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
                 padding=0, stride=1, dilation=1, mode=0, alpha=T(1)) where T
    (mode == 0 || mode == 1) ||
        throw(ArgumentError("conv2d only supports mode=0 or 1."))

    img_w, img_h, img_c, _n = size(x)
    ker_w, ker_h, ker_c, _k = size(w)

    # Every kernel of `w` must have as many channels as `x`.
    img_c == ker_c || throw(DimensionMismatch())

    pad_w, pad_h = psize(padding, x)
    str_w, str_h = psize(stride, x)
    dil_w, dil_h = psize(dilation, x)

    # The same padding is used on the low and the high side of each axis.
    cdims = ConvDims{(img_w, img_h), (ker_w, ker_h), img_c, (str_w, str_h),
                     (pad_w, pad_w, pad_h, pad_h), (dil_w, dil_h), mode == 0}()
    return conv2d!(y, x, w, cdims; alpha=alpha)
end
280
375
281
376
function conv2d_grad_w! (dw:: AbstractArray{T,4} , x:: AbstractArray{T,4} , dy:: AbstractArray{T,4} ;
@@ -332,37 +427,37 @@ function im2col2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,
332
427
n:: Int , p1:: Int , p2:: Int , s1:: Int , s2:: Int , mode:: Int ) where T
333
428
Wx,Hx,Cx,Nx = size (x)
334
429
Ww,Hw,C1,C2 = w
335
- xn = x[:, :, :, n]
336
- im2col_2d! (xn,x2,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1 ,1 ,mode)
430
+ xn = view (x, :, :, :, n)
431
+ cdims = ConvDims {(Wx,Hx),(Ww,Hw),Cx,(s1,s2),(p1,p1,p2,p2),(1,1), mode == 0} ()
432
+ im2col_2d! (xn,x2,cdims)
337
433
return x2
338
434
end
339
435
340
436
# Unpack image `n` of the batch `x` into the patch matrix `x2` for a kernel array `w`,
# with explicit dilation (d1, d2). The image is passed on as a view (no copy), and the
# padding p1/p2 is applied to both sides of its axis. Returns `x2`.
function im2col2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
                   n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int,
                   mode::Int) where T
    img_w, img_h, img_c, _n = size(x)
    ker_w, ker_h, _c, _k = size(w)
    image = view(x, :, :, :, n)
    # mode == 0 requests the flipped kernel.
    cdims = ConvDims{(img_w, img_h), (ker_w, ker_h), img_c, (s1, s2), (p1, p1, p2, p2),
                     (d1, d2), mode == 0}()
    im2col_2d!(image, x2, cdims)
    return x2
end
348
445
349
446
# Fold the patch matrix `x2` back into image `n` of the batch `x`, for a kernel whose
# size is given as the tuple `w`. Dilation is fixed to (1, 1). The image is handed to
# col2im_2d!() as a view, so the result lands directly in `x`. Returns `x`.
function col2im2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
                   n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T
    img_w, img_h, img_c, _n = size(x)
    ker_w, ker_h, _c, _k = w
    image = view(x, :, :, :, n)
    col2im_2d!(x2, image, img_w, img_h, img_c, ker_w, ker_h, p1, p2, s1, s2, 1, 1, mode)
    return x
end
358
454
359
455
# Fold the patch matrix `x2` back into image `n` of the batch `x`, for a kernel array `w`
# and explicit dilation (d1, d2). The image is handed to col2im_2d!() as a view, so the
# result lands directly in `x`. Returns `x`.
function col2im2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArray{T,2},
                   n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int,
                   mode::Int) where T
    img_w, img_h, img_c, _n = size(x)
    ker_w, ker_h, _c, _k = size(w)
    image = view(x, :, :, :, n)
    col2im_2d!(x2, image, img_w, img_h, img_c, ker_w, ker_h, p1, p2, s1, s2, d1, d2, mode)
    return x
end
368
463
@@ -445,7 +540,7 @@ function im2col3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArr
445
540
s3:: Int , d1:: Int , d2:: Int , d3:: Int , mode:: Int ) where T
446
541
Wx,Hx,Dx,Cx,Nx = size (x)
447
542
Ww,Hw,Dw,C1,C2 = size (w)
448
- xn = x[ :, :, :, :, n]
543
+ xn = view (x, :, :, :, :, n)
449
544
im2col_3d! (xn,x2,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode)
450
545
return x2
451
546
end
@@ -455,8 +550,7 @@ function col2im3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArr
455
550
s3:: Int , d1:: Int , d2:: Int , d3:: Int , mode:: Int ) where T
456
551
Wx,Hx,Dx,Cx,Nx = size (x)
457
552
Ww,Hw,Dw,C1,C2 = size (w)
458
- xn = x[ :, :, :, :, n]
553
+ xn = view (x, :, :, :, :, n)
459
554
col2im_3d! (x2,xn,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode)
460
- x[:, :, :, :, n] = xn
461
555
return x
462
556
end
0 commit comments