Skip to content

Commit 9cc8a3c

Browse files
Merge pull request #31 from pxl-th/master
Add CUDA kernels for grid sampling
2 parents 7ed1255 + 6480727 commit 9cc8a3c

File tree

7 files changed

+121
-5
lines changed

7 files changed

+121
-5
lines changed

ext/NNlibCUDA/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
1313
CUDA = "3.3.1"
14-
NNlib = "0.7.25"
14+
NNlib = "0.7.31"
1515
julia = "1.6"
1616

1717
[extras]

ext/NNlibCUDA/src/NNlibCUDA.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Random, Statistics
77
const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}
88

99
include("upsample.jl")
10+
include("sampling.jl")
1011
include("activations.jl")
1112
include("batchedmul.jl")
1213
include("scatter.jl")

ext/NNlibCUDA/src/cudnn/batchnorm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,4 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
121121
db .= vec(sum(dy, dims=rdims))
122122
end
123123
end
124-
124+

ext/NNlibCUDA/src/sampling.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
@inline function NNlib._safe_add!(dx::CuDeviceArray{T, 4}, value, ix, iy, c, n) where T
2+
@inbounds CUDA.@atomic dx[ix, iy, c, n] += value
3+
end
4+
5+
function grid_sample_kernel!(n_elem, output, input, grid, padding_mode)
6+
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
7+
if index < n_elem
8+
iW, iH, iC, _ = size(input)
9+
_, gW, gH, _ = size(grid)
10+
11+
w = index % gW + 1
12+
h = (index ÷ gW) % gH + 1
13+
n = index ÷ (gW * gH) + 1
14+
NNlib._grid_sample_kernel!(output, input, grid, padding_mode, w, h, n, iW, iH, iC)
15+
end
16+
nothing
17+
end
18+
19+
function ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, input, grid, padding_mode)
20+
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
21+
if index < n_elem
22+
iW, iH, iC, _ = size(input)
23+
_, gW, gH, _ = size(grid)
24+
25+
w = index % gW + 1
26+
h = (index ÷ gW) % gH + 1
27+
n = index ÷ (gW * gH) + 1
28+
NNlib._∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, padding_mode, w, h, n, iW, iH, iC)
29+
end
30+
nothing
31+
end
32+
33+
function NNlib.grid_sample(x::CuArray{T, 4}, grid::CuArray{V, 4}; padding_mode = :zeros) where {T, V}
34+
pad = Val(padding_mode)
35+
_, _, xC, xN = size(x)
36+
_, gW, gH, _ = size(grid)
37+
n_elem = gW * gH * xN
38+
y = similar(x, T, (gW, gH, xC, xN))
39+
40+
kernel = @cuda launch=false grid_sample_kernel!(n_elem, y, x, grid, pad)
41+
config = launch_configuration(kernel.fun; max_threads=256)
42+
threads = min(n_elem, config.threads)
43+
blocks = cld(n_elem, threads)
44+
kernel(n_elem, y, x, grid, pad; threads=threads, blocks=blocks)
45+
y
46+
end
47+
48+
function NNlib.∇grid_sample::CuArray{T, 4}, x::CuArray{T, 4}, grid::CuArray{V, 4}; padding_mode = :zeros) where {T, V}
49+
pad = Val(padding_mode)
50+
xN = size(x, 4)
51+
_, gW, gH, _ = size(grid)
52+
n_elem = gW * gH * xN
53+
dx, dgrid = CUDA.zeros(T, size(x)), similar(grid)
54+
55+
kernel = @cuda launch=false ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, x, grid, pad)
56+
config = launch_configuration(kernel.fun; max_threads=256)
57+
threads = min(n_elem, config.threads)
58+
blocks = cld(n_elem, threads)
59+
kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks)
60+
dx, dgrid
61+
end

ext/NNlibCUDA/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ include("softmax.jl")
1818
include("batchnorm.jl")
1919
include("scatter.jl")
2020
include("gather.jl")
21+
include("sampling.jl")
2122
end

ext/NNlibCUDA/test/sampling.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
@testset "Grid Sampling" begin
2+
for T in (Float32, Float64)
3+
x = ones(T, (2, 2, 1, 1))
4+
grid = Array{T}(undef, 2, 2, 2, 1)
5+
grid[:, 1, 1, 1] .= (-1, -1)
6+
grid[:, 2, 1, 1] .= (1, -1)
7+
grid[:, 1, 2, 1] .= (-1, 1)
8+
grid[:, 2, 2, 1] .= (1, 1)
9+
10+
∇grid_true = Array{T}(undef, size(grid))
11+
∇grid_true[:, :, 1, 1] = [[0.0, 0.0] [-0.5, 0.0]]
12+
∇grid_true[:, :, 2, 1] = [[0.0, -0.5] [-0.5, -0.5]]
13+
14+
x_gpu, grid_gpu = CuArray(x), CuArray(grid)
15+
16+
padding_mode = :zeros
17+
y_gpu = grid_sample(x_gpu, grid_gpu; padding_mode=padding_mode)
18+
@test x == collect(y_gpu)
19+
@test eltype(y_gpu) == T
20+
21+
external_grad = CUDA.ones(T, size(y_gpu))
22+
∇input, ∇grid = ∇grid_sample(external_grad, x_gpu, grid_gpu; padding_mode=padding_mode)
23+
@test x == collect(∇input)
24+
@test ∇grid_true == collect(∇grid)
25+
@test eltype(∇input) == T
26+
@test eltype(∇grid) == T
27+
28+
padding_mode = :border
29+
fill!(∇grid_true, 0.0)
30+
sampled = grid_sample(x_gpu, grid_gpu; padding_mode=padding_mode)
31+
@test x == collect(sampled)
32+
@test eltype(sampled) == T
33+
34+
∇input, ∇grid = ∇grid_sample(external_grad, x_gpu, grid_gpu; padding_mode=padding_mode)
35+
@test x == collect(∇input)
36+
@test ∇grid_true == collect(∇grid)
37+
@test eltype(∇input) == T
38+
@test eltype(∇grid) == T
39+
end
40+
end
41+
42+
@testset "Compare grid sampling with NNlib" begin
43+
w, h, c, n = 16, 16, 2, 4
44+
input = rand(Float64, w, h, c, n)
45+
grid = zeros(Float64, 2, w, h, n)
46+
@inbounds for xi in 1:w, yi in 1:h, ni in 1:n
47+
grid[1, xi, yi, ni] = (xi / w) * 2.0 - 1.0 + 0.01
48+
grid[2, xi, yi, ni] = (yi / h) * 2.0 - 1.0
49+
end
50+
for padding_mode in (:zeros, :border)
51+
gputest(grid_sample, input, grid; atol=1e-6, padding_mode=padding_mode)
52+
end
53+
end

ext/NNlibCUDA/test/test_utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ function gputest(f, xs...; checkgrad=true, atol=1e-10, kws...)
55
cpu_out = f(cpu_in...; kws...)
66
gpu_out = f(gpu_in...; kws...)
77
@test collect(cpu_out) collect(gpu_out)
8-
8+
99
if checkgrad
10-
cpu_grad = gradient((x...) -> sum(f(x...)), cpu_in...)
11-
gpu_grad = gradient((x...) -> sum(f(x...)), gpu_in...)
10+
cpu_grad = gradient((x...) -> sum(f(x...; kws...)), cpu_in...)
11+
gpu_grad = gradient((x...) -> sum(f(x...; kws...)), gpu_in...)
1212
for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad)
1313
if cpu_g === nothing
1414
@test gpu_g === nothing

0 commit comments

Comments
 (0)