Skip to content

Commit 5740fd6

Browse files
author
Max Freudenberg
committed
use array dimension to dispatch kernels
1 parent 06619ac commit 5740fd6

File tree

1 file changed

+40
-106
lines changed

1 file changed

+40
-106
lines changed

ext/NNlibCUDA/src/upsample.jl

Lines changed: 40 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,45 @@
5252
end
5353
end
5454

55+
# Forward pass of linear upsampling on the GPU, dispatched by array dimension:
# the sibling device kernels `upsample_linear_cuda_kernel!` have methods for
# CuDeviceArray{<:Any,3} (linear), {<:Any,4} (bilinear) and {<:Any,5} (trilinear),
# so a single wrapper covers all three cases.
#
# `y` is the preallocated output, `x` the input; both are (spatial..., C, N) shaped.
# `align_corners` selects the PyTorch-style corner-alignment convention for the
# coordinate mapping. Returns `y`.
function NNlib.upsample_linear_kernel!(y::CuArray{T,N}, x::CuArray{T,N}; align_corners=true) where {T,N}
    # One device thread per output element across the N-2 spatial dimensions
    # (the trailing channel and batch dimensions are looped over inside the kernel).
    out_size = prod(size(y)[1:N-2])

    # Per-dimension input/output scale factor, converted to T so the device
    # kernel receives a concretely-typed scalar per spatial dimension.
    ratios = if align_corners
        # With aligned corners the first and last samples coincide, hence (len-1).
        # NOTE(review): assumes spatial sizes > 1 here, otherwise this divides by zero
        # — presumably guarded by the caller; confirm against NNlib's CPU path.
        ntuple(i -> T((size(x, i) - 1) / (size(y, i) - 1)), N - 2)
    else
        ntuple(i -> T(size(x, i) / size(y, i)), N - 2)
    end

    # Compile without launching so we can ask CUDA for an occupancy-based
    # launch configuration, then launch with at most 256 threads per block.
    kernel = @cuda launch=false upsample_linear_cuda_kernel!(out_size, ratios..., x, y, align_corners)
    config = launch_configuration(kernel.fun; max_threads=256)
    n_threads = Base.min(out_size, config.threads)
    n_blocks = cld(out_size, n_threads)
    kernel(out_size, ratios..., x, y, align_corners; threads=n_threads, blocks=n_blocks)
    return y
end
71+
72+
# Backward pass of linear upsampling on the GPU, dispatched by array dimension
# exactly like the forward wrapper: the device kernels carry methods for
# 3-, 4- and 5-dimensional CuDeviceArrays (linear / bilinear / trilinear).
#
# `Δ` is the gradient backpropagated from downstream layers; `dx` is the
# preallocated gradient w.r.t. the input, which is accumulated into and returned.
function NNlib.∇upsample_linear_kernel!(dx::CuArray{T,N}, Δ::CuArray{T,N}; align_corners=true) where {T,N}
    # One device thread per element of the incoming gradient's spatial dims.
    in_size = prod(size(Δ)[1:N-2])

    # Scale factors are reversed compared to the forward pass: they map from
    # gradient (output) space back to input space.
    ratios = if align_corners
        # NOTE(review): divides by zero when a spatial size of Δ is 1 —
        # presumably excluded by the caller; confirm against NNlib's CPU path.
        ntuple(i -> T((size(dx, i) - 1) / (size(Δ, i) - 1)), N - 2)
    else
        ntuple(i -> T(size(dx, i) / size(Δ, i)), N - 2)
    end

    # Compile first to query an occupancy-based configuration, capped at 256
    # threads per block, then launch for real.
    kernel = @cuda launch=false ∇upsample_linear_cuda_kernel!(in_size, ratios..., Δ, dx, align_corners)
    config = launch_configuration(kernel.fun; max_threads=256)
    n_threads = Base.min(in_size, config.threads)
    n_blocks = cld(in_size, n_threads)
    kernel(in_size, ratios..., Δ, dx, align_corners; threads=n_threads, blocks=n_blocks)
    return dx
end
88+
5589

5690
###########
5791
# linear
5892
###########
59-
function upsample_linear_wcn_kernel!(n_elem, rwidth, x, y, align_corners)
93+
function upsample_linear_cuda_kernel!(n_elem, rwidth, x::CuDeviceArray{<:Any, 3}, y::CuDeviceArray{<:Any, 3}, align_corners)
6094
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
6195

6296
if index < n_elem
@@ -86,7 +120,7 @@ function upsample_linear_wcn_kernel!(n_elem, rwidth, x, y, align_corners)
86120
end
87121

88122
# Δ is the gradient backpropagated from downstream layers
89-
function upsample_linear_wcn_kernel!(n_elem, rwidth, Δ, dx, align_corners)
123+
function upsample_linear_cuda_kernel!(n_elem, rwidth, Δ::CuDeviceArray{<:Any, 3}, dx::CuDeviceArray{<:Any, 3}, align_corners)
90124
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
91125

92126
if index < n_elem
@@ -115,44 +149,11 @@ function ∇upsample_linear_wcn_kernel!(n_elem, rwidth, Δ, dx, align_corners)
115149
return nothing
116150
end
117151

118-
function NNlib.upsample_linear_wcn!(y::CuArray{T,3}, x::CuArray{T,3}; align_corners=true) where T
119-
out_size = size(y)[1] # w
120-
121-
if align_corners
122-
ratios = ntuple(i -> T((size(x,i)-1) / (size(y,i)-1)), 1)
123-
else
124-
ratios = ntuple(i -> T(size(x,i) / size(y,i)), 1)
125-
end
126-
127-
kernel = @cuda launch=false upsample_linear_wcn_kernel!(out_size, ratios..., x, y, align_corners)
128-
config = launch_configuration(kernel.fun; max_threads=256)
129-
threads = Base.min(out_size, config.threads)
130-
blocks = cld(out_size, threads)
131-
kernel(out_size, ratios..., x, y, align_corners; threads=threads, blocks=blocks)
132-
return y
133-
end
134-
135-
function NNlib.∇upsample_linear_wcn!(dx::CuArray{T,3}, Δ::CuArray{T,3}; align_corners=true) where T
136-
in_size = size(Δ)[1]
137-
if align_corners
138-
ratios = ntuple(i -> T((size(dx,i)-1) / (size(Δ,i)-1)), 1) # reversed compared to forward pass
139-
else
140-
ratios = ntuple(i -> T(size(dx,i) / size(Δ,i)), 1)
141-
end
142-
143-
kernel = @cuda launch=false ∇upsample_linear_wcn_kernel!(in_size, ratios..., Δ, dx, align_corners)
144-
config = launch_configuration(kernel.fun; max_threads=256)
145-
threads = Base.min(in_size, config.threads)
146-
blocks = cld(in_size, threads)
147-
kernel(in_size, ratios..., Δ, dx, align_corners; threads=threads, blocks=blocks)
148-
return dx
149-
end
150-
151152

152153
###########
153154
# bilinear
154155
###########
155-
function upsample_bilinear_whcn_kernel!(n_elem, rwidth, rheight, x, y, align_corners)
156+
function upsample_linear_cuda_kernel!(n_elem, rwidth, rheight, x::CuDeviceArray{<:Any, 4}, y::CuDeviceArray{<:Any, 4}, align_corners)
156157
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
157158

158159
if index < n_elem
@@ -194,7 +195,7 @@ function upsample_bilinear_whcn_kernel!(n_elem, rwidth, rheight, x, y, align_cor
194195
end
195196

196197
# Δ is the gradient backpropagated from downstream layers
197-
function upsample_bilinear_whcn_kernel!(n_elem, rwidth, rheight, Δ, dx, align_corners)
198+
function upsample_linear_cuda_kernel!(n_elem, rwidth, rheight, Δ::CuDeviceArray{<:Any, 4}, dx::CuDeviceArray{<:Any, 4}, align_corners)
198199
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
199200

200201
if index < n_elem
@@ -237,44 +238,11 @@ function ∇upsample_bilinear_whcn_kernel!(n_elem, rwidth, rheight, Δ, dx, alig
237238
return nothing
238239
end
239240

240-
function NNlib.upsample_bilinear_whcn!(y::CuArray{T,4}, x::CuArray{T,4}; align_corners=true) where T
241-
out_size = prod(size(y)[1:2]) # w*h
242-
243-
if align_corners
244-
ratios = ntuple(i -> T((size(x,i)-1) / (size(y,i)-1)), 2)
245-
else
246-
ratios = ntuple(i -> T(size(x,i) / size(y,i)), 2)
247-
end
248-
249-
kernel = @cuda launch=false upsample_bilinear_whcn_kernel!(out_size, ratios..., x, y, align_corners)
250-
config = launch_configuration(kernel.fun; max_threads=256)
251-
threads = Base.min(out_size, config.threads)
252-
blocks = cld(out_size, threads)
253-
kernel(out_size, ratios..., x, y, align_corners; threads=threads, blocks=blocks)
254-
return y
255-
end
256-
257-
function NNlib.∇upsample_bilinear_whcn!(dx::CuArray{T,4}, Δ::CuArray{T,4}; align_corners=true) where T
258-
in_size = prod(size(Δ)[1:2])
259-
if align_corners
260-
ratios = ntuple(i -> T((size(dx,i)-1) / (size(Δ,i)-1)), 2) # reversed compared to forward pass
261-
else
262-
ratios = ntuple(i -> T(size(dx,i) / size(Δ,i)), 2)
263-
end
264-
265-
kernel = @cuda launch=false ∇upsample_bilinear_whcn_kernel!(in_size, ratios..., Δ, dx, align_corners)
266-
config = launch_configuration(kernel.fun; max_threads=256)
267-
threads = Base.min(in_size, config.threads)
268-
blocks = cld(in_size, threads)
269-
kernel(in_size, ratios..., Δ, dx, align_corners; threads=threads, blocks=blocks)
270-
return dx
271-
end
272-
273241

274242
###########
275243
# trilinear
276244
###########
277-
function upsample_trilinear_whdcn_kernel!(n_elem, rwidth, rheight, rdepth, x, y, align_corners)
245+
function upsample_linear_cuda_kernel!(n_elem, rwidth, rheight, rdepth, x::CuDeviceArray{<:Any, 5}, y::CuDeviceArray{<:Any, 5}, align_corners)
278246
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
279247

280248
if index < n_elem
@@ -337,7 +305,7 @@ function upsample_trilinear_whdcn_kernel!(n_elem, rwidth, rheight, rdepth, x, y,
337305
end
338306

339307
# Δ is the gradient backpropagated from downstream layers
340-
function upsample_trilinear_whdcn_kernel!(n_elem, rwidth, rheight, rdepth, Δ, dx, align_corners)
308+
function upsample_linear_cuda_kernel!(n_elem, rwidth, rheight, rdepth, Δ::CuDeviceArray{<:Any, 5}, dx::CuDeviceArray{<:Any, 5}, align_corners)
341309
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
342310

343311
if index < n_elem
@@ -389,37 +357,3 @@ function ∇upsample_trilinear_whdcn_kernel!(n_elem, rwidth, rheight, rdepth, Δ
389357
end # if
390358
return nothing
391359
end
392-
393-
function NNlib.upsample_trilinear_whdcn!(y::CuArray{T,5}, x::CuArray{T,5}; align_corners=true) where T
394-
out_size = prod(size(y)[1:3]) # w*h*d
395-
396-
if align_corners
397-
ratios = ntuple(i -> T((size(x,i)-1) / (size(y,i)-1)), 3)
398-
else
399-
ratios = ntuple(i -> T(size(x,i) / size(y,i)), 3)
400-
end
401-
402-
kernel = @cuda launch=false upsample_trilinear_whdcn_kernel!(out_size, ratios..., x, y, align_corners)
403-
config = launch_configuration(kernel.fun; max_threads=256)
404-
threads = Base.min(out_size, config.threads)
405-
blocks = cld(out_size, threads)
406-
kernel(out_size, ratios..., x, y, align_corners; threads=threads, blocks=blocks)
407-
return y
408-
end
409-
410-
function NNlib.∇upsample_trilinear_whdcn!(dx::CuArray{T,5}, Δ::CuArray{T,5}; align_corners=true) where T
411-
in_size = prod(size(Δ)[1:3])
412-
413-
if align_corners
414-
ratios = ntuple(i -> T((size(dx,i)-1) / (size(Δ,i)-1)), 3) # reversed compared to forward pass
415-
else
416-
ratios = ntuple(i -> T(size(dx,i) / size(Δ,i)), 3)
417-
end
418-
419-
kernel = @cuda launch=false ∇upsample_trilinear_whdcn_kernel!(in_size, ratios..., Δ, dx, align_corners)
420-
config = launch_configuration(kernel.fun; max_threads=256)
421-
threads = Base.min(in_size, config.threads)
422-
blocks = cld(in_size, threads)
423-
kernel(in_size, ratios..., Δ, dx, align_corners; threads=threads, blocks=blocks)
424-
return dx
425-
end

0 commit comments

Comments
 (0)