
Commit 73cae99

Author: Nikola Janjusevic
Commit message: added fold/unfold and gpu tests
1 parent: df70552

File tree: 4 files changed (+204 -0 lines)


ext/NNlibCUDA/src/NNlibCUDA.jl

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ include("activations.jl")
 include("batchedadjtrans.jl")
 include("batchedmul.jl")
 include("ctc.jl")
+include("fold.jl")
 include("scatter.jl")
 include("gather.jl")
 include("utils.jl")

ext/NNlibCUDA/src/fold.jl

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@

function unfold_kernel!(T::Type, col, x, cdims, max_idx)
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    if index > max_idx
        return nothing
    end

    # Extract those nice, compile-time constant type parameters from `cdims`.
    width, height, depth = NNlib.input_size(cdims)
    kernel_w, kernel_h, kernel_d = NNlib.kernel_size(cdims)
    C_in = NNlib.channels_in(cdims)
    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = NNlib.padding(cdims)
    dil_w, dil_h, dil_d = NNlib.dilation(cdims)
    stride_w, stride_h, stride_d = NNlib.stride(cdims)
    output_size = NNlib.output_size(cdims)

    I = CartesianIndices(output_size)
    w, h, d = I[index].I  # output spatial indices

    # A helper function to project from output (w, h) to input (input_w, input_h)
    @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1

    @inbounds for c in 1:C_in, b in 1:size(x, 5)
        for kd in 1:kernel_d,
            kh in 1:kernel_h,
            kw in 1:kernel_w

            input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
            input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
            input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w

            kidxs = NNlib.kernel_index(kw, kh, kd, cdims)

            out_of_bounds = (
                input_kd <= 0 || input_kd > depth ||
                input_kh <= 0 || input_kh > height ||
                input_kw <= 0 || input_kw > width
            )
            if out_of_bounds
                col[index, kidxs..., c, b] = T(0)
                continue
            end

            # Copy the data over
            xval::T = x[input_kw, input_kh, input_kd, c, b]
            col[index, kidxs..., c, b] = xval
        end
    end

    return nothing
end

function fold_kernel!(T::Type, x, col, cdims, max_idx)
    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    if index > max_idx
        return nothing
    end

    # Extract those nice, compile-time constant type parameters from `cdims`.
    width, height, depth = NNlib.input_size(cdims)
    kernel_w, kernel_h, kernel_d = NNlib.kernel_size(cdims)
    C_in = NNlib.channels_in(cdims)
    pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = NNlib.padding(cdims)
    dil_w, dil_h, dil_d = NNlib.dilation(cdims)
    stride_w, stride_h, stride_d = NNlib.stride(cdims)
    output_size = NNlib.output_size(cdims)

    I = CartesianIndices(output_size)
    w, h, d = I[index].I  # output spatial indices

    # A helper function to project from output (w, h) to input (input_w, input_h)
    @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1

    @inbounds for c in 1:C_in, b in 1:size(x, 5)
        for kd in 1:kernel_d,
            kh in 1:kernel_h,
            kw in 1:kernel_w

            input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
            input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
            input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w

            out_of_bounds = (
                input_kd <= 0 || input_kd > depth ||
                input_kh <= 0 || input_kh > height ||
                input_kw <= 0 || input_kw > width
            )
            if out_of_bounds
                continue
            end

            # Accumulate the data (overlapping patches may write to the same input element)
            kidxs = NNlib.kernel_index(kw, kh, kd, cdims)
            cval::T = col[index, kidxs..., c, b]
            CUDA.@atomic x[input_kw, input_kh, input_kd, c, b] += cval
        end
    end

    return nothing
end

function NNlib.unfold!(col::AnyCuArray{cT,3}, x::AnyCuArray{xT,5}, cdims::NNlib.DenseConvDims) where {cT, xT}
    if NNlib.spatial_dims(cdims) != 3
        throw(DimensionMismatch("unfold!() only accepts 3d convolutional inputs"))
    end

    output_size = NNlib.output_size(cdims)
    kernel_w, kernel_h, kernel_d = NNlib.kernel_size(cdims)
    C_in = NNlib.channels_in(cdims)

    # Reshape col for easy access.
    col_reshaped = reshape(col, (
        prod(output_size),
        # By input patch size
        kernel_w,
        kernel_h,
        kernel_d,
        C_in,
        size(x, 5),
    ))

    max_idx = prod(output_size)
    args = cT, col_reshaped, x, cdims, max_idx
    kernel = @cuda launch=false unfold_kernel!(args...)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(max_idx, config.threads)
    blocks = cld(max_idx, threads)
    kernel(args...; threads=threads, blocks=blocks)
    return col
end

function NNlib.fold!(x::AnyCuArray{xT,5}, col::AnyCuArray{cT,3}, cdims::NNlib.DenseConvDims) where {xT, cT}
    if NNlib.spatial_dims(cdims) != 3
        throw(DimensionMismatch("fold!() only accepts 3d convolutional inputs"))
    end

    # going to accumulate into x
    fill!(x, xT(0))

    output_size = NNlib.output_size(cdims)
    kernel_w, kernel_h, kernel_d = NNlib.kernel_size(cdims)
    C_in = NNlib.channels_in(cdims)

    # Reshape col for easy access.
    col_reshaped = reshape(col, (
        prod(output_size),
        # input patch size
        kernel_w,
        kernel_h,
        kernel_d,
        C_in,
        size(x, 5),
    ))

    max_idx = prod(output_size)
    args = xT, x, col_reshaped, cdims, max_idx
    kernel = @cuda launch=false fold_kernel!(args...)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(max_idx, config.threads)
    blocks = cld(max_idx, threads)
    kernel(args...; threads=threads, blocks=blocks)
    return x
end
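A rough usage sketch of the new GPU path through NNlib's high-level API (not part of the commit; package loading, array sizes, and keyword arguments here are illustrative assumptions; only the `unfold(x, cdims)` and `fold(cols, size(x), cdims)` call signatures are taken from the tests below):

using CUDA, NNlib, NNlibCUDA

x = CUDA.rand(Float32, 8, 8, 8, 3, 2)          # 3d-spatial input, layout (W, H, D, C, N)
w = CUDA.rand(Float32, 2, 2, 2, 3, 4)          # a kernel array, used only to build `cdims`
cdims = DenseConvDims(x, w; stride=2, padding=1)

cols = NNlib.unfold(x, cdims)                  # im2col: roughly (prod(output_size), prod(kernel)*C_in, N)
xacc = NNlib.fold(cols, size(x), cdims)        # col2im: accumulates overlapping patches back to x's shape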

ext/NNlibCUDA/test/fold.jl

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@

@testset "fold" begin
    # Test for agreement between CPU/GPU versions, across a variety of kwargs
    options = Dict{Any, Any}.((
        (), (:dilation => 2), (:flipkernel => true), (:stride => 2),
        (:padding => 1),
        (:padding => (1,0)),
        (:padding => (0,1)),
        (:padding => (2,3)),
    ))

    C_in = 3
    C_out = 4
    batch_size = 1

    @testset "spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
        for opts in options
            if :padding in keys(opts)
                padding = opts[:padding]
                if 1 < length(padding) && length(padding) != 2spatial_rank
                    opts[:padding] = ntuple(i -> padding[mod1(i,2)] .+ 2div(i-1,2), 2spatial_rank)
                end
            end

            x = rand(Float64, fill(8, spatial_rank)..., C_in, batch_size)
            w = rand(Float64, fill(2, spatial_rank)..., C_in, C_out)
            cdims = DenseConvDims(x, w; opts...)
            y = unfold(x, cdims)

            # test equivalence of fold/unfold across GPU/CPU
            gputest(x -> NNlib.unfold(x, cdims), x)
            gputest(y -> NNlib.fold(y, size(x), cdims), y)
        end
    end
end
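A possible additional sanity check, not part of this commit: because `fold_kernel!` accumulates with `+=` exactly the entries that `unfold_kernel!` reads, `fold` acts as the adjoint of `unfold`, which can be verified numerically. The sketch below assumes the same packages as above; names and sizes are illustrative, and `sum(a .* b)` stands in for a dot product to stay GPU-friendly:

using CUDA, NNlib, NNlibCUDA, Test

x = CUDA.rand(Float32, 8, 8, 8, 3, 2)
w = CUDA.rand(Float32, 2, 2, 2, 3, 4)       # only used to build `cdims`
cdims = DenseConvDims(x, w; padding=1)

cols = NNlib.unfold(x, cdims)
y = CUDA.rand(Float32, size(cols)...)       # arbitrary array in "col" space

# <unfold(x), y> should match <x, fold(y, size(x))> up to floating-point error
@test sum(cols .* y) ≈ sum(x .* NNlib.fold(y, size(x), cdims))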

ext/NNlibCUDA/test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ include("batchedmul.jl")
 include("upsample.jl")
 include("conv.jl")
 include("ctc.jl")
+include("fold.jl")
 include("pooling.jl")
 include("softmax.jl")
 include("batchnorm.jl")
