@@ -1,101 +1,52 @@
-function unfold_kernel!(T::Type, col, x, cdims, max_idx)
+function unfold_kernel!(col::AbstractArray{T}, x, col_size, input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx) where {T}
     index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
 
-    if index > max_idx
-        return nothing
-    end
-
-    # Extract those nice, compile-time constant type parameters from `cdims`.
-    width, height, depth = NNlib.input_size(cdims)
-    kernel_w, kernel_h, kernel_d = NNlib.kernel_size(cdims)
-    C_in = NNlib.channels_in(cdims)
-    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = NNlib.padding(cdims)
-    dil_w, dil_h, dil_d = NNlib.dilation(cdims)
-    stride_w, stride_h, stride_d = NNlib.stride(cdims)
-    output_size = NNlib.output_size(cdims)
-
-    I = CartesianIndices(output_size)
-    w, h, d = I[index].I  # output spatial indices
-
-    # A helper function to project from output (w, h) to input (input_w, input_h)
-    @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1
+    @inbounds if index <= max_idx
+        i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I  # col indices
+        w, h, d = CartesianIndices(output_size)[i].I  # x indices
 
-    @inbounds for c in 1:C_in, b in 1:size(x, 5)
-        for kd in 1:kernel_d,
-            kh in 1:kernel_h,
-            kw in 1:kernel_w
+        # project
+        w, h, d = @. ((w, h, d) - 1)*stride - pad_lo + 1 + ((kw, kh, kd) - 1)*dilation
 
-            input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
-            input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
-            input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
-
-            kidxs = NNlib.kernel_index(kw, kh, kd, cdims)
-
-            out_of_bounds = (
-                input_kd <= 0 || input_kd > depth ||
-                input_kh <= 0 || input_kh > height ||
-                input_kw <= 0 || input_kw > width
-            )
-            if out_of_bounds
-                col[index, kidxs..., c, b] = T(0)
-                continue
+        if !flipkernel
+            kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1
         end
 
-            # Copy the data over
-            xval::T = x[input_kw, input_kh, input_kd, c, b]
-            col[index, kidxs..., c, b] = xval
+        # check out of bounds
+        if any((w, h, d) .<= 0 .|| (w, h, d) .> input_size)
+            col[i, kw, kh, kd, c, b] = T(0)
+            return nothing
        end
+
+        xval::T = x[w, h, d, c, b]
+        col[i, kw, kh, kd, c, b] = xval
     end
 
     return nothing
 end
 
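The rewrite replaces the old one-thread-per-output-position scheme, which looped over channels, batch, and kernel offsets, with one thread per element of the reshaped `col` array: two `CartesianIndices` lookups decode the flat thread index, and the projection, kernel-flip, and bounds logic run once per element with no loops. A minimal CPU sketch of the same mapping (a hypothetical `unfold_reference!`, not part of this commit) may help follow the index arithmetic:

function unfold_reference!(col, x, col_size, input_size, output_size,
                           kernel_size, flipkernel, stride, pad_lo, dilation)
    T = eltype(col)
    for index in 1:prod(col_size)
        # decode the flat index into col coordinates, then output coordinates
        i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I
        w, h, d = CartesianIndices(output_size)[i].I
        # same projection as the kernel: stride, then padding, then dilation offset
        w, h, d = @. ((w, h, d) - 1)*stride - pad_lo + 1 + ((kw, kh, kd) - 1)*dilation
        if !flipkernel  # flipkernel == false is a true convolution, so reverse the kernel offsets
            kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1
        end
        if any((w, h, d) .<= 0 .|| (w, h, d) .> input_size)
            col[i, kw, kh, kd, c, b] = T(0)   # padding region reads as zero
        else
            col[i, kw, kh, kd, c, b] = x[w, h, d, c, b]
        end
    end
    return col
end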
-function fold_kernel!(T::Type, x, col, cdims, max_idx)
+function fold_kernel!(x::AbstractArray{T}, col, col_size, input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx) where {T}
     index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
 
-    if index > max_idx
-        return nothing
-    end
+    @inbounds if index <= max_idx
+        i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I  # col indices
+        w, h, d = CartesianIndices(output_size)[i].I  # x indices
 
-    # Extract those nice, compile-time constant type parameters from `cdims`.
-    width, height, depth = NNlib.input_size(cdims)
-    kernel_w, kernel_h, kernel_d = NNlib.kernel_size(cdims)
-    C_in = NNlib.channels_in(cdims)
-    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = NNlib.padding(cdims)
-    dil_w, dil_h, dil_d = NNlib.dilation(cdims)
-    stride_w, stride_h, stride_d = NNlib.stride(cdims)
-    output_size = NNlib.output_size(cdims)
+        # project
+        w, h, d = @. ((w, h, d) - 1)*stride - pad_lo + 1 + ((kw, kh, kd) - 1)*dilation
 
-    I = CartesianIndices(output_size)
-    w, h, d = I[index].I  # output spatial indices
-
-    # A helper function to project from output (w, h) to input (input_w, input_h)
-    @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1
-
-    @inbounds for c in 1:C_in, b in 1:size(x, 5)
-        for kd in 1:kernel_d,
-            kh in 1:kernel_h,
-            kw in 1:kernel_w
-
-            input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
-            input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
-            input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
-
-            out_of_bounds = (
-                input_kd <= 0 || input_kd > depth ||
-                input_kh <= 0 || input_kh > height ||
-                input_kw <= 0 || input_kw > width
-            )
-            if out_of_bounds
-                continue
+        # check out of bounds
+        if any((w, h, d) .<= 0 .|| (w, h, d) .> input_size)
+            return nothing
        end
 
-            # Copy the data over
-            kidxs = NNlib.kernel_index(kw, kh, kd, cdims)
-            cval::T = col[index, kidxs..., c, b]
-            CUDA.@atomic x[input_kw, input_kh, input_kd, c, b] += cval
+        if !flipkernel
+            kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1
        end
+
+        cval::T = col[i, kw, kh, kd, c, b]
+        CUDA.@atomic x[w, h, d, c, b] += cval
     end
 
     return nothing
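`fold_kernel!` keeps `CUDA.@atomic` because whenever the stride is smaller than the (dilated) kernel extent, neighbouring sliding windows overlap: several `col` entries project onto the same `x[w, h, d, c, b]` voxel, and with one thread per `col` element their `+=` updates would otherwise race. A 1-d illustration of the overlap, with hypothetical sizes (length-5 input, kernel width 3, stride 1, no padding, dilation 1):

hits = zeros(Int, 5)
for w_out in 1:3, k in 1:3                 # 3 output positions, kernel width 3
    w_in = (w_out - 1)*1 + (k - 1)*1 + 1   # the same projection formula
    hits[w_in] += 1
end
hits  # == [1, 2, 3, 2, 1]: interior voxels receive several contributions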
@@ -106,23 +57,20 @@ function NNlib.unfold!(col::AnyCuArray{cT,3}, x::AnyCuArray{xT,5}, cdims::NNlib.
         throw(DimensionMismatch("unfold!() only accepts 3d convolutional inputs"))
     end
 
-    output_size = NNlib.output_size(cdims)
-    kernel_w, kernel_h, kernel_d = NNlib.kernel_size(cdims)
+    input_size = NNlib.input_size(cdims)
     C_in = NNlib.channels_in(cdims)
+    kernel_size = NNlib.kernel_size(cdims)
+    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = NNlib.padding(cdims)
+    pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)
+    dilation = NNlib.dilation(cdims)
+    stride = NNlib.stride(cdims)
+    output_size = NNlib.output_size(cdims)
+    flipkernel = NNlib.flipkernel(cdims)
+
+    col_reshaped = reshape(col, (prod(output_size), kernel_size..., C_in, :))
 
-    # Reshape col for easy access.
-    col_reshaped = reshape(col, (
-        prod(output_size),
-        # By input patch size
-        kernel_w,
-        kernel_h,
-        kernel_d,
-        C_in,
-        size(x, 5),
-    ))
-
-    max_idx = prod(output_size)
-    args = cT, col_reshaped, x, cdims, max_idx
+    max_idx = prod(size(col))
+    args = col_reshaped, x, size(col_reshaped), input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx
     kernel = @cuda launch=false unfold_kernel!(args...)
     config = launch_configuration(kernel.fun; max_threads=256)
     threads = min(max_idx, config.threads)
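For reference, a sketch of driving this method directly; the shapes follow the `NNlib.unfold!(col, x, cdims)` signature in the hunk header, while the array sizes and `DenseConvDims` keywords are illustrative assumptions (the high-level `NNlib.unfold` wrapper would normally allocate `col` for you):

using CUDA, NNlib

x = CUDA.rand(Float32, 8, 8, 8, 2, 4)                  # WHDCN input, C_in = 2, batch = 4
w = CUDA.rand(Float32, 3, 3, 3, 2, 5)                  # only its size is used by DenseConvDims
cdims = NNlib.DenseConvDims(x, w; stride=1, padding=1)
col = CUDA.zeros(Float32,
    prod(NNlib.output_size(cdims)),                    # one row per output position
    prod(NNlib.kernel_size(cdims)) * NNlib.channels_in(cdims),
    size(x, 5))
NNlib.unfold!(col, x, cdims)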
@@ -139,23 +87,20 @@ function NNlib.fold!(x::AnyCuArray{xT,5}, col::AnyCuArray{cT,3}, cdims::NNlib.De
     # going to accumulate into x
     fill!(x, xT(0))
 
-    output_size = NNlib.output_size(cdims)
-    kernel_w, kernel_h, kernel_d = NNlib.kernel_size(cdims)
+    input_size = NNlib.input_size(cdims)
     C_in = NNlib.channels_in(cdims)
+    kernel_size = NNlib.kernel_size(cdims)
+    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = NNlib.padding(cdims)
+    pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)
+    dilation = NNlib.dilation(cdims)
+    stride = NNlib.stride(cdims)
+    output_size = NNlib.output_size(cdims)
+    flipkernel = NNlib.flipkernel(cdims)
+
+    col_reshaped = reshape(col, (prod(output_size), kernel_size..., C_in, :))
 
-    # Reshape col for easy access.
-    col_reshaped = reshape(col, (
-        prod(output_size),
-        # input patch size
-        kernel_w,
-        kernel_h,
-        kernel_d,
-        C_in,
-        size(x, 5),
-    ))
-
-    max_idx = prod(output_size)
-    args = xT, x, col_reshaped, cdims, max_idx
+    max_idx = prod(size(col))
+    args = x, col_reshaped, size(col_reshaped), input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx
     kernel = @cuda launch=false fold_kernel!(args...)
     config = launch_configuration(kernel.fun; max_threads=256)
     threads = min(max_idx, config.threads)
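Both host hunks end inside the standard CUDA.jl occupancy idiom: `@cuda launch=false` compiles the kernel without running it, `launch_configuration` asks the driver for a thread count that keeps the device occupied (capped at 256 here), and the lines cut off by the hunk would compute the block count and launch. Sketched in full, assuming the elided lines follow that usual pattern:

kernel = @cuda launch=false fold_kernel!(args...)
config = launch_configuration(kernel.fun; max_threads=256)
threads = min(max_idx, config.threads)   # never more threads than work items
blocks = cld(max_idx, threads)           # enough blocks to cover every col element
kernel(args...; threads, blocks)         # launch with the chosen shape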