Skip to content

Commit 885b7c5

Browse files
committed
Improve GradOp performance on GPU for multiple dims
1 parent 27188c6 commit 885b7c5

File tree

2 files changed

+96
-46
lines changed

2 files changed

+96
-46
lines changed
Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,54 @@
1-
function LinearOperatorCollection.grad!(res::vecT, img::vecT, shape, dim) where {vecT <: AbstractGPUVector}
2-
δ = zeros(Int, length(shape))
3-
δ[dim] = 1
4-
δ = Tuple(δ)
5-
di = CartesianIndex(δ)
6-
7-
gpu_call(reshape(res, shape .- δ), reshape(img,shape), di) do ctx, res_, img_, di_
8-
idx = @cartesianidx(res_)
9-
@inbounds res_[idx] = img_[idx] - img_[idx + di_]
10-
return nothing
1+
function LinearOperatorCollection.grad!(res::vecT, img::vecT, shape::NTuple{N,Int64}, di::CartesianIndex{N}) where {vecT <: AbstractGPUVector, N}
2+
res = reshape(res, shape .- Tuple(di))
3+
4+
if length(res) > 0
5+
gpu_call(grad_kernel!, res, reshape(img,shape), di)
116
end
12-
7+
138
return res
149
end
1510

16-
# adjoint of directional gradients
17-
function LinearOperatorCollection.grad_t!(res::vecT, g::vecT, shape::NTuple{N,Int64}, dim::Int64) where {T, vecT <: AbstractGPUVector{T}, N}
18-
δ = zeros(Int, length(shape))
19-
δ[dim] = 1
20-
δ = Tuple(δ)
21-
di = CartesianIndex(δ)
11+
function grad_kernel!(ctx, res, img, di)
12+
idx = @cartesianidx(res)
13+
@inbounds res[idx] = img[idx] - img[idx + di]
14+
return nothing
15+
end
2216

17+
# adjoint of directional gradients
18+
function LinearOperatorCollection.grad_t!(res::vecT, g::vecT, shape::NTuple{N,Int64}, di::CartesianIndex{N}) where {T, vecT <: AbstractGPUVector{T}, N}
2319
res_ = reshape(res,shape)
24-
g_ = reshape(g, shape .- δ)
20+
g_ = reshape(g, shape .- Tuple(di))
2521

2622
fill!(res, zero(T))
27-
gpu_call(res_, g_, di, elements = length(g)) do ctx, res_k, g_k, di_k
28-
idx = @cartesianidx(g_k)
29-
@inbounds res_k[idx] = g_k[idx]
30-
return nothing
23+
if length(g_) > 0
24+
gpu_call(grad_t_kernel_1!, res_, g_, di, elements = length(g))
25+
gpu_call(grad_t_kernel_2!, res_, g_, di, elements = length(g))
3126
end
27+
end
3228

33-
gpu_call(res_, g_, di, elements = length(g)) do ctx, res_k, g_k, di_k
34-
idx = @cartesianidx(g_k)
35-
@inbounds res_k[idx + di_k] -= g_k[idx]
36-
return nothing
37-
end
29+
function grad_t_kernel_1!(ctx, res, g, di)
30+
idx = @cartesianidx(g)
31+
@inbounds res[idx] += g[idx]
32+
return nothing
33+
end
34+
35+
function grad_t_kernel_2!(ctx, res, g, di)
36+
idx = @cartesianidx(g)
37+
@inbounds res[idx + di] -= g[idx]
38+
return nothing
3839
end
40+
41+
function LinearOperatorCollection.grad_t!(res::vecT, g::vecT, shape::NTuple{N,Int64}, dirs, dims, dim_ends, tmp) where {T, vecT <: AbstractGPUVector{T}, N}
42+
dim_start = 1
43+
res = reshape(res, shape)
44+
45+
fill!(res, zero(eltype(res)))
46+
for (i, di) in enumerate(dirs)
47+
g_ = reshape(view(g, dim_start:dim_ends[i]), shape .- Tuple(di))
48+
if length(g_) > 0
49+
gpu_call(grad_t_kernel_1!, res, g_, di, elements = length(g))
50+
gpu_call(grad_t_kernel_2!, res, g_, di, elements = length(g))
51+
end
52+
dim_start = dim_ends[i] + 1
53+
end
54+
end

src/GradientOp.jl

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,46 +16,80 @@ function GradientOp(::Type{T}; shape::NTuple{N,Int}, dims=1:length(shape), kwarg
1616
return GradientOpImpl(T, shape, dims; kwargs...)
1717
end
1818

19-
function GradientOpImpl(T::Type, shape::NTuple{N,Int}, dims; kwargs...) where N
20-
return vcat([GradientOpImpl(T, shape, dim; kwargs...) for dim ∈ dims]...)
19+
function GradientOpImpl(T::Type, shape::NTuple{N,Int}, dims; S = Vector{T}) where N
20+
dirs = CartesianIndex{N}[]
21+
cols = Int64[]
22+
for dim in dims
23+
δ = zeros(Int32, N)
24+
δ[dim] = 1
25+
δ = NTuple{N}(δ)
26+
di = CartesianIndex(δ)
27+
push!(dirs, di)
28+
push!(cols, div((shape[dim]-1)*prod(shape), shape[dim]))
29+
end
30+
dim_ends = accumulate(+, cols)
31+
32+
nrow = sum(cols)
33+
ncol = prod(shape)
34+
35+
tmp = S(undef, ncol)
36+
37+
return LinearOperator{T}(nrow, ncol, false, false,
38+
(res,x) -> (grad!(res,x,shape,dirs, dims, dim_ends)),
39+
(res,x) -> (grad_t!(res,x,shape,dirs, dims, dim_ends, tmp)),
40+
(res,x) -> (grad_t!(res,x,shape,dirs, dims, dim_ends, tmp)),
41+
S = S)
2142
end
2243

2344
function GradientOpImpl(T::Type, shape::NTuple{N,Int}, dim::Int; S = Vector{T}) where N
2445
nrow = div( (shape[dim]-1)*prod(shape), shape[dim] )
2546
ncol = prod(shape)
47+
δ = zeros(Int, length(shape))
48+
δ[dim] = 1
49+
δ = Tuple(δ)
50+
dir = CartesianIndex(δ)
2651
return LinearOperator{T}(nrow, ncol, false, false,
27-
(res,x) -> (grad!(res,x,shape,dim)),
28-
(res,x) -> (grad_t!(res,x,shape,dim)),
29-
(res,x) -> (grad_t!(res,x,shape,dim)),
52+
(res,x) -> (grad!(res,x,shape,dir)),
53+
(res,x) -> (grad_t!(res,x,shape,dir)),
54+
(res,x) -> (grad_t!(res,x,shape,dir)),
3055
S = S)
3156
end
3257

58+
function grad!(res::T, img::U, shape, dirs, dims, dim_ends) where {T<:AbstractVector, U<:AbstractVector}
59+
dim_start = 1
60+
61+
for (i, dir) in enumerate(dirs)
62+
grad!(view(res, dim_start:dim_ends[i]), img, shape, dir)
63+
dim_start = dim_ends[i] + 1
64+
end
65+
end
66+
3367
# directional gradients
34-
function grad!(res::T, img::U, shape, dim) where {T<:AbstractVector, U<:AbstractVector}
68+
function grad!(res::T, img::U, shape::NTuple{N,Int64}, di::CartesianIndex{N}) where {N, T<:AbstractVector, U<:AbstractVector}
3569
img_ = reshape(img,shape)
3670

37-
δ = zeros(Int, length(shape))
38-
δ[dim] = 1
39-
δ = Tuple(δ)
40-
di = CartesianIndex(δ)
41-
42-
res_ = reshape(res, shape .- δ)
71+
res_ = reshape(res, shape .- Tuple(di))
4372

4473
Threads.@threads for i ∈ CartesianIndices(res_)
4574
@inbounds res_[i] = img_[i] - img_[i + di]
4675
end
4776
end
4877

78+
function grad_t!(res::T, g::U, shape, dirs, dims, dims_end, tmp) where {T<:AbstractVector, U<:AbstractVector}
79+
dim_start = 1
4980

50-
# adjoint of directional gradients
51-
function grad_t!(res::T, g::U, shape::NTuple{N,Int64}, dim::Int64) where {T<:AbstractVector, U<:AbstractVector, N}
52-
δ = zeros(Int, length(shape))
53-
δ[dim] = 1
54-
δ = Tuple(δ)
55-
di = CartesianIndex(δ)
81+
fill!(res, zero(eltype(res)))
82+
for (i, dir) in enumerate(dirs)
83+
grad_t!(tmp, view(g, dim_start:dims_end[i]), shape, dir)
84+
dim_start = dims_end[i] + 1
85+
res .= res .+ tmp
86+
end
87+
end
5688

89+
# adjoint of directional gradients
90+
function grad_t!(res::T, g::U, shape::NTuple{N,Int64}, di::CartesianIndex{N}) where {N, T<:AbstractVector, U<:AbstractVector}
5791
res_ = reshape(res,shape)
58-
g_ = reshape(g, shape .- δ)
92+
g_ = reshape(g, shape .- Tuple(di))
5993

6094
res_ .= 0
6195
Threads.@threads for i ∈ CartesianIndices(g_)

0 commit comments

Comments
 (0)