Skip to content

Commit dc00cf7

Browse files
committed
port more general GradientOp from RegLS to LinOpCol
1 parent 3c95e72 commit dc00cf7

File tree

3 files changed

+74
-91
lines changed

3 files changed

+74
-91
lines changed

ext/LinearOperatorNFFTExt/NFFTOp.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
import Base.adjoint
22

3-
function LinearOperatorCollection.createLinearOperator(::Type{Op}; kargs...) where Op <: NFFTOp{T} where T <: Number
4-
return NFFTOp(T; kargs...)
5-
end
6-
73
function LinearOperatorCollection.NFFTOp(::Type{T};
84
shape::Tuple, nodes::AbstractMatrix{U}, toeplitz=false, oversamplingFactor=1.25,
95
kernelSize=3, kargs...) where {U <: Number, T <: Number}

src/GradientOp.jl

Lines changed: 38 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,117 +1,72 @@
11
function LinearOperatorCollection.GradientOp(::Type{T};
2-
shape::Tuple, dim::Union{Nothing,Int64}=nothing) where T <: Number
3-
if dim == nothing
2+
shape::Tuple, dims=nothing) where T <: Number
3+
if dims == nothing
44
return GradientOpImpl(T, shape)
55
else
6-
return GradientOpImpl(T, shape, dim)
6+
return GradientOpImpl(T, shape, dims)
77
end
88
end
99

10-
11-
"""
12-
GradientOpImpl(T::Type, shape::NTuple{1,Int64})
13-
14-
1d gradient operator for an array of size `shape`
1510
"""
16-
GradientOpImpl(T::Type, shape::NTuple{1,Int64}) = GradientOpImpl(T,shape,1)
11+
GradOp(T::Type, shape::NTuple{N,Int64})
1712
13+
Nd gradient operator for an array of size `shape`
1814
"""
19-
GradientOpImpl(T::Type, shape::NTuple{2,Int64})
20-
21-
2d gradient operator for an array of size `shape`
22-
"""
23-
function GradientOpImpl(T::Type, shape::NTuple{2,Int64})
24-
return vcat( GradientOpImpl(T,shape,1), GradientOpImpl(T,shape,2) )
15+
function GradientOpImpl(T::Type, shape)
16+
shape = typeof(shape) <: Number ? (shape,) : shape # convert Number to Tuple
17+
return vcat([GradientOpImpl(T, shape, i) for i ∈ eachindex(shape)]...)
2518
end
2619

2720
"""
28-
GradientOpImpl(T::Type, shape::NTuple{3,Int64})
21+
GradOp(T::Type, shape::NTuple{N,Int64}, dims)
2922
30-
3d gradient operator for an array of size `shape`
31-
"""
32-
function GradientOpImpl(T::Type, shape::NTuple{3,Int64})
33-
return vcat( GradientOpImpl(T,shape,1), GradientOpImpl(T,shape,2), GradientOpImpl(T,shape,3) )
34-
end
35-
36-
"""
37-
gradOp(T::Type, shape::NTuple{N,Int64}, dim::Int64) where N
38-
39-
directional gradient operator along the dimension `dim`
23+
directional gradient operator along the dimensions `dims`
4024
for an array of size `shape`
4125
"""
42-
function GradientOpImpl(T::Type, shape::NTuple{N,Int64}, dim::Int64) where N
26+
function GradientOpImpl(T::Type, shape::NTuple{N,Int64}, dims) where N
27+
return vcat([GradientOpImpl(T, shape, dim) for dim ∈ dims]...)
28+
end
29+
function GradientOpImpl(T::Type, shape::NTuple{N,Int64}, dim::Integer) where N
4330
nrow = div( (shape[dim]-1)*prod(shape), shape[dim] )
4431
ncol = prod(shape)
4532
return LinearOperator{T}(nrow, ncol, false, false,
46-
(res,x) -> (grad!(res,x,shape,dim) ),
47-
(res,x) -> (grad_t!(res,x,shape,dim) ),
33+
(res,x) -> (grad!(res,x,shape,dim) ),
34+
(res,x) -> (grad_t!(res,x,shape,dim) ),
4835
nothing )
4936
end
5037

5138
# directional gradients
52-
function grad!(res::T, img::U, shape::NTuple{1,Int64}, dim::Int64) where {T<:AbstractVector,U<:AbstractVector}
53-
res .= img[1:end-1].-img[2:end]
54-
end
55-
56-
function grad!(res::T, img::U, shape::NTuple{2,Int64}, dim::Int64) where {T<:AbstractVector,U<:AbstractVector}
57-
img = reshape(img,shape)
39+
function grad!(res::T, img::U, shape, dim) where {T<:AbstractVector, U<:AbstractVector}
40+
img_ = reshape(img,shape)
5841

59-
if dim==1
60-
res .= vec(img[1:end-1,:].-img[2:end,:])
61-
else
62-
res .= vec(img[:,1:end-1].-img[:,2:end])
63-
end
64-
end
42+
δ = zeros(Int, length(shape))
43+
δ[dim] = 1
44+
δ = Tuple(δ)
45+
di = CartesianIndex(δ)
6546

66-
function grad!(res::T,img::U, shape::NTuple{3,Int64}, dim::Int64) where {T<:AbstractVector,U<:AbstractVector}
67-
img = reshape(img,shape)
47+
res_ = reshape(res, shape .- δ)
6848

69-
if dim==1
70-
res .= vec(img[1:end-1,:,:].-img[2:end,:,:])
71-
elseif dim==2
72-
res.= vec(img[:,1:end-1,:].-img[:,2:end,:])
73-
else
74-
res.= vec(img[:,:,1:end-1].-img[:,:,2:end])
49+
Threads.@threads for i ∈ CartesianIndices(res_)
50+
@inbounds res_[i] = img_[i] - img_[i + di]
7551
end
7652
end
7753

54+
7855
# adjoint of directional gradients
79-
function grad_t!(res::T, g::U, shape::NTuple{1,Int64}, dim::Int64) where {T<:AbstractVector,U<:AbstractVector}
80-
res .= zero(eltype(g))
81-
res[1:shape[1]-1] .= g
82-
res[2:shape[1]] .-= g
83-
end
56+
function grad_t!(res::T, g::U, shape::NTuple{N,Int64}, dim::Int64) where {T<:AbstractVector, U<:AbstractVector, N}
57+
δ = zeros(Int, length(shape))
58+
δ[dim] = 1
59+
δ = Tuple(δ)
60+
di = CartesianIndex(δ)
8461

85-
function grad_t!(res::T, g::U, shape::NTuple{2,Int64}, dim::Int64) where {T<:AbstractVector,U<:AbstractVector}
86-
res .= zero(eltype(g))
8762
res_ = reshape(res,shape)
63+
g_ = reshape(g, shape .- δ)
8864

89-
if dim==1
90-
g = reshape(g,shape[1]-1,shape[2])
91-
res_[1:shape[1]-1,:] .= g
92-
res_[2:shape[1],:] .-= g
93-
else
94-
g = reshape(g,shape[1],shape[2]-1)
95-
res_[:,1:shape[2]-1] .= g
96-
res_[:,2:shape[2]] .-= g
65+
res_ .= 0
66+
Threads.@threads for i ∈ CartesianIndices(g_)
67+
@inbounds res_[i] = g_[i]
9768
end
98-
end
99-
100-
function grad_t!(res::T, g::U, shape::NTuple{3,Int64}, dim::Int64) where {T<:AbstractVector,U<:AbstractVector}
101-
res .= zero(eltype(g))
102-
res_ = reshape(res,shape)
103-
104-
if dim==1
105-
g = reshape(g,shape[1]-1,shape[2],shape[3])
106-
res_[1:shape[1]-1,:,:] .= g
107-
res_[2:shape[1],:,:] .-= g
108-
elseif dim==2
109-
g = reshape(g,shape[1],shape[2]-1,shape[3])
110-
res_[:,1:shape[2]-1,:] .= g
111-
res_[:,2:shape[2],:] .-= g
112-
else
113-
g = reshape(g,shape[1],shape[2],shape[3]-1)
114-
res_[:,:,1:shape[3]-1] .= g
115-
res_[:,:,2:shape[3]] .-= g
69+
Threads.@threads for i ∈ CartesianIndices(g_)
70+
@inbounds res_[i + di] -= g_[i]
11671
end
117-
end
72+
end

test/testOperators.jl

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ end
102102

103103
function testGradOp1d(N=512)
104104
x = rand(N)
105-
G = GradientOp(eltype(x), shape=size(x))
105+
G = GradientOp(eltype(x); shape=size(x))
106106
G0 = Bidiagonal(ones(N),-ones(N-1), :U)[1:N-1,:]
107107

108108
y = G*x
@@ -112,12 +112,12 @@ function testGradOp1d(N=512)
112112
xr = transpose(G)*y
113113
xr0 = transpose(G0)*y0
114114

115-
@test norm(y - y0) / norm(y0) ≈ 0 atol=0.001
115+
@test norm(xr - xr0) / norm(xr0) ≈ 0 atol=0.001
116116
end
117117

118118
function testGradOp2d(N=64)
119119
x = repeat(1:N,1,N)
120-
G = GradientOp(eltype(x), shape=size(x))
120+
G = GradientOp(eltype(x); shape=size(x))
121121
G_1d = Bidiagonal(ones(N),-ones(N-1), :U)[1:N-1,:]
122122

123123
y = G*vec(x)
@@ -133,6 +133,37 @@ function testGradOp2d(N=64)
133133
@test norm(xr - xr0) / norm(xr0) ≈ 0 atol=0.001
134134
end
135135

136+
function testDirectionalGradOp(N=64)
137+
x = rand(ComplexF64,N,N)
138+
G1 = GradientOp(eltype(x); shape=size(x), dims=1)
139+
G2 = GradientOp(eltype(x); shape=size(x), dims=2)
140+
G_1d = Bidiagonal(ones(N),-ones(N-1), :U)[1:N-1,:]
141+
142+
y1 = G1*vec(x)
143+
y2 = G2*vec(x)
144+
y1_ref = zeros(ComplexF64, N-1,N)
145+
y2_ref = zeros(ComplexF64, N, N-1)
146+
for i=1:N
147+
y1_ref[:,i] .= G_1d*x[:,i]
148+
y2_ref[i,:] .= G_1d*x[i,:]
149+
end
150+
151+
@test norm(y1-vec(y1_ref)) / norm(y1_ref) ≈ 0 atol=0.001
152+
@test norm(y2-vec(y2_ref)) / norm(y2_ref) ≈ 0 atol=0.001
153+
154+
x1r = transpose(G1)*y1
155+
x2r = transpose(G2)*y2
156+
157+
x1r_ref = zeros(ComplexF64, N,N)
158+
x2r_ref = zeros(ComplexF64, N,N)
159+
for i=1:N
160+
x1r_ref[:,i] .= transpose(G_1d)*y1_ref[:,i]
161+
x2r_ref[i,:] .= transpose(G_1d)*y2_ref[i,:]
162+
end
163+
@test norm(x1r-vec(x1r_ref)) / norm(x1r_ref) ≈ 0 atol=0.001
164+
@test norm(x2r-vec(x2r_ref)) / norm(x2r_ref) ≈ 0 atol=0.001
165+
end
166+
136167
function testSampling(N=64)
137168
x = rand(ComplexF64,N,N)
138169
# index-based sampling
@@ -275,7 +306,8 @@ end
275306
testWeighting(512)
276307
@info "test GradientOp"
277308
testGradOp1d(512)
278-
testGradOp2d(64)
309+
testGradOp2d(64)
310+
testDirectionalGradOp(64)
279311
@info "test SamplingOp"
280312
testSampling(64)
281313
@info "test WaveletOp"

0 commit comments

Comments
 (0)