Skip to content

Commit 11e4f48

Browse files
author
Nikola Janjusevic
committed
fold/unfold kernel rewrite
1 parent 73cae99 commit 11e4f48

File tree

1 file changed

+53
-108
lines changed

1 file changed

+53
-108
lines changed

ext/NNlibCUDA/src/fold.jl

Lines changed: 53 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,101 +1,52 @@
"""
    unfold_kernel!(col, x, col_size, input_size, output_size, kernel_size,
                   flipkernel, stride, pad_lo, dilation, max_idx)

CUDA kernel for `NNlib.unfold!`: one thread per element of the reshaped `col`
array. Each thread decodes its linear `index` into `(i, kw, kh, kd, c, b)`
col coordinates, projects the flattened output spatial position `i` back onto
the input `x` (accounting for stride, lower padding and dilation), and copies
the matching input value into `col` — or writes zero when the projected
position falls inside the padding region.

All size/stride arguments are passed as plain tuples (not a `ConvDims`) so the
kernel stays GPU-friendly. `max_idx` is `prod(col_size)`; threads past it exit.
"""
function unfold_kernel!(col::AbstractArray{T}, x, col_size, input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx) where {T}
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        # col indices: flattened output position, kernel offsets, channel, batch
        i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I
        # output spatial indices
        w, h, d = CartesianIndices(output_size)[i].I

        # project output position + kernel offset onto input coordinates
        w, h, d = @. ((w, h, d) - 1)*stride - pad_lo + 1 + ((kw, kh, kd) - 1)*dilation

        # convolution (as opposed to cross-correlation) reverses the kernel axes
        if !flipkernel
            kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1
        end

        # projected positions landing in the padding region contribute zeros
        if any((w, h, d) .<= 0 .|| (w, h, d) .> input_size)
            col[i, kw, kh, kd, c, b] = zero(T)
            return nothing
        end

        xval::T = x[w, h, d, c, b]
        col[i, kw, kh, kd, c, b] = xval
    end

    return nothing
end
"""
    fold_kernel!(x, col, col_size, input_size, output_size, kernel_size,
                 flipkernel, stride, pad_lo, dilation, max_idx)

CUDA kernel for `NNlib.fold!` — the adjoint of `unfold_kernel!`. One thread per
element of the reshaped `col` array: each thread decodes its linear `index`
into `(i, kw, kh, kd, c, b)` col coordinates, projects the flattened output
position `i` back onto the input `x`, and accumulates the col value into `x`.
The accumulation uses `CUDA.@atomic` because overlapping patches map several
col entries onto the same input element. Positions that project into the
padding region are simply dropped. The caller is expected to have zeroed `x`.
"""
function fold_kernel!(x::AbstractArray{T}, col, col_size, input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx) where {T}
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    @inbounds if index <= max_idx
        # col indices: flattened output position, kernel offsets, channel, batch
        i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I
        # output spatial indices
        w, h, d = CartesianIndices(output_size)[i].I

        # project output position + kernel offset onto input coordinates
        w, h, d = @. ((w, h, d) - 1)*stride - pad_lo + 1 + ((kw, kh, kd) - 1)*dilation

        # convolution (as opposed to cross-correlation) reverses the kernel
        # axes; done before the bounds check to mirror unfold_kernel! (the
        # check depends only on w, h, d, so the order is behavior-neutral)
        if !flipkernel
            kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1
        end

        # contributions that project into the padding region are discarded
        if any((w, h, d) .<= 0 .|| (w, h, d) .> input_size)
            return nothing
        end

        cval::T = col[i, kw, kh, kd, c, b]
        # atomic: overlapping patches write to the same input element
        CUDA.@atomic x[w, h, d, c, b] += cval
    end

    return nothing
end
@@ -106,23 +57,20 @@ function NNlib.unfold!(col::AnyCuArray{cT,3}, x::AnyCuArray{xT,5}, cdims::NNlib.
10657
throw(DimensionMismatch("unfold!() only accepts 3d convoluitional inputs"))
10758
end
10859

109-
output_size = NNlib.output_size(cdims)
110-
kernel_w, kernel_h, kernel_d = NNlib.kernel_size(cdims)
60+
input_size = NNlib.input_size(cdims)
11161
C_in = NNlib.channels_in(cdims)
62+
kernel_size = NNlib.kernel_size(cdims)
63+
pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = NNlib.padding(cdims)
64+
pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)
65+
dilation = NNlib.dilation(cdims)
66+
stride = NNlib.stride(cdims)
67+
output_size = NNlib.output_size(cdims)
68+
flipkernel = NNlib.flipkernel(cdims)
69+
70+
col_reshaped = reshape(col, (prod(output_size), kernel_size..., C_in, :))
11271

113-
# Reshape col for easy access.
114-
col_reshaped = reshape(col, (
115-
prod(output_size),
116-
# By input patch size
117-
kernel_w,
118-
kernel_h,
119-
kernel_d,
120-
C_in,
121-
size(x, 5),
122-
))
123-
124-
max_idx = prod(output_size)
125-
args = cT, col_reshaped, x, cdims, max_idx
72+
max_idx = prod(size(col))
73+
args = col_reshaped, x, size(col_reshaped), input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx
12674
kernel = @cuda launch=false unfold_kernel!(args...)
12775
config = launch_configuration(kernel.fun; max_threads=256)
12876
threads = min(max_idx, config.threads)
@@ -139,23 +87,20 @@ function NNlib.fold!(x::AnyCuArray{xT,5}, col::AnyCuArray{cT,3}, cdims::NNlib.De
13987
# going to accumulate into x
14088
fill!(x, xT(0))
14189

142-
output_size = NNlib.output_size(cdims)
143-
kernel_w, kernel_h, kernel_d = NNlib.kernel_size(cdims)
90+
input_size = NNlib.input_size(cdims)
14491
C_in = NNlib.channels_in(cdims)
92+
kernel_size = NNlib.kernel_size(cdims)
93+
pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = NNlib.padding(cdims)
94+
pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)
95+
dilation = NNlib.dilation(cdims)
96+
stride = NNlib.stride(cdims)
97+
output_size = NNlib.output_size(cdims)
98+
flipkernel = NNlib.flipkernel(cdims)
99+
100+
col_reshaped = reshape(col, (prod(output_size), kernel_size..., C_in, :))
145101

146-
# Reshape col for easy access.
147-
col_reshaped = reshape(col, (
148-
prod(output_size),
149-
# input patch size
150-
kernel_w,
151-
kernel_h,
152-
kernel_d,
153-
C_in,
154-
size(x, 5),
155-
))
156-
157-
max_idx = prod(output_size)
158-
args = xT, x, col_reshaped, cdims, max_idx
102+
max_idx = prod(size(col))
103+
args = x, col_reshaped, size(col_reshaped), input_size, output_size, kernel_size, flipkernel, stride, pad_lo, dilation, max_idx
159104
kernel = @cuda launch=false fold_kernel!(args...)
160105
config = launch_configuration(kernel.fun; max_threads=256)
161106
threads = min(max_idx, config.threads)

0 commit comments

Comments
 (0)