Skip to content

Commit a040376

Browse files
committed
port remaining Ops to new interface and prepare for NFFTOp
1 parent 41fc693 commit a040376

File tree

14 files changed

+372
-92
lines changed

14 files changed

+372
-92
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,23 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1313
[weakdeps]
1414
Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
1515
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
16+
NFFT = "efe261a4-0d2b-5849-be55-fc731d526b0d"
1617

1718
[extensions]
1819
LinearOperatorWaveletExt = "Wavelets"
1920
LinearOperatorFFTWExt = "FFTW"
21+
LinearOperatorNFFTExt = "NFFT"
2022

2123
[compat]
2224
julia = "1.6"
2325
FFTW = "1.0"
2426
LinearOperators = "2.3.3"
2527
Reexport = "1.0"
2628
Wavelets = "0.9"
29+
NFFT = "0.13"
2730

2831
[extras]
2932
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3033

3134
[targets]
32-
test = ["Test", "FFTW", "Wavelets"]
35+
test = ["Test", "FFTW", "Wavelets", "NFFT"]

ext/LinearOperatorFFTWExt/DCTOp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function LinearOperatorCollection.constructLinearOperator(::Type{Op};
55
return DCTOpImpl(T, shape, dcttype)
66
end
77

8-
mutable struct DCTOpImpl{T} <: AbstractLinearOperatorFromCollection{T}
8+
mutable struct DCTOpImpl{T} <: DCTOp{T}
99
nrow :: Int
1010
ncol :: Int
1111
symmetric :: Bool

ext/LinearOperatorFFTWExt/DSTOp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function LinearOperatorCollection.constructLinearOperator(::Type{Op};
55
return DSTOpImpl(T, shape)
66
end
77

8-
mutable struct DSTOpImpl{T} <: AbstractLinearOperatorFromCollection{T}
8+
mutable struct DSTOpImpl{T} <: DSTOp{T}
99
nrow :: Int
1010
ncol :: Int
1111
symmetric :: Bool

ext/LinearOperatorFFTWExt/FFTOp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ function LinearOperatorCollection.constructLinearOperator(::Type{Op};
66
return FFTOpImpl(T, shape, shift; unitary, cuda)
77
end
88

9-
mutable struct FFTOpImpl{T} <: AbstractLinearOperatorFromCollection{T}
9+
mutable struct FFTOpImpl{T} <: FFTOp{T}
1010
nrow :: Int
1111
ncol :: Int
1212
symmetric :: Bool
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
module LinearOperatorNFFTExt
2+
3+
using LinearOperatorCollection, NFFT
4+
5+
include("NFFTOp.jl")
6+
7+
end

ext/LinearOperatorNFFTExt/NFFTOp.jl

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
export NFFTOpImpl
2+
import Base.adjoint
3+
4+
function LinearOperatorCollection.constructLinearOperator(::Type{Op};
5+
shape::Tuple, nodes::AbstractMatrix{T}, toeplitz=false, oversamplingFactor=1.25,
6+
kernelSize=3, kargs...) where Op <: NFFTOp{T} where T <: Number
7+
return NFFTOpImpl(T, shape, nodes; toeplitz, oversamplingFactor, kernelSize, kargs... )
8+
end
9+
10+
mutable struct NFFTOpImpl{T} <: NFFTOp{T}
11+
nrow :: Int
12+
ncol :: Int
13+
symmetric :: Bool
14+
hermitian :: Bool
15+
prod! :: Function
16+
tprod! :: Nothing
17+
ctprod! :: Function
18+
nprod :: Int
19+
ntprod :: Int
20+
nctprod :: Int
21+
args5 :: Bool
22+
use_prod5! :: Bool
23+
allocated5 :: Bool
24+
Mv5 :: Vector{T}
25+
Mtu5 :: Vector{T}
26+
plan
27+
toeplitz :: Bool
28+
end
29+
30+
LinearOperators.storage_type(op::NFFTOpImpl) = typeof(op.Mv5)
31+
32+
"""
33+
NFFTOpImpl(shape::Tuple, tr::Trajectory; kargs...)
34+
NFFTOpImpl(shape::Tuple, tr::AbstractMatrix; kargs...)
35+
36+
generates a `NFFTOpImpl` which evaluates the MRI Fourier signal encoding operator using the NFFT.
37+
38+
# Arguments:
39+
* `shape::NTuple{D,Int64}` - size of image to encode/reconstruct
40+
* `tr` - Either a `Trajectory` object, or a `ND x Nsamples` matrix for an ND-dimenensional (e.g. 2D or 3D) NFFT with `Nsamples` k-space samples
41+
* (`nodes=nothing`) - Array containg the trajectory nodes (redundant)
42+
* (`kargs`) - additional keyword arguments
43+
"""
44+
function NFFTOpImpl(shape::Tuple, tr::AbstractMatrix{T}; toeplitz=false, oversamplingFactor=1.25, kernelSize=3, kargs...) where {T}
45+
46+
plan = plan_nfft(tr, shape, m=kernelSize, σ=oversamplingFactor, precompute=NFFT.TENSOR,
47+
fftflags=FFTW.ESTIMATE, blocking=true)
48+
49+
return NFFTOpImpl{Complex{T}}(size(tr,2), prod(shape), false, false
50+
, (res,x) -> produ!(res,plan,x)
51+
, nothing
52+
, (res,y) -> ctprodu!(res,plan,y)
53+
, 0, 0, 0, false, false, false, Complex{T}[], Complex{T}[]
54+
, plan, toeplitz)
55+
end
56+
57+
function produ!(y::AbstractVector, plan::NFFT.NFFTPlan, x::AbstractVector)
58+
mul!(y, plan, reshape(x,plan.N))
59+
end
60+
61+
function ctprodu!(x::AbstractVector, plan::NFFT.NFFTPlan, y::AbstractVector)
62+
mul!(reshape(x, plan.N), adjoint(plan), y)
63+
end
64+
65+
66+
function Base.copy(S::NFFTOpImpl{T}) where {T}
67+
plan = copy(S.plan)
68+
return NFFTOpImpl{T}(size(plan.k,2), prod(plan.N), false, false
69+
, (res,x) -> produ!(res,plan,x)
70+
, nothing
71+
, (res,y) -> ctprodu!(res,plan,y)
72+
, 0, 0, 0, false, false, false, T[], T[]
73+
, plan, S.toeplitz)
74+
end
75+
76+
77+
78+
#########################################################################
79+
### Toeplitz Operator ###
80+
#########################################################################
81+
82+
mutable struct NFFTToeplitzNormalOp{T,D,W} <: AbstractLinearOperator{T}
83+
nrow :: Int
84+
ncol :: Int
85+
symmetric :: Bool
86+
hermitian :: Bool
87+
prod! :: Function
88+
tprod! :: Nothing
89+
ctprod! :: Nothing
90+
nprod :: Int
91+
ntprod :: Int
92+
nctprod :: Int
93+
args5 :: Bool
94+
use_prod5! :: Bool
95+
allocated5 :: Bool
96+
Mv5 :: Vector{T}
97+
Mtu5 :: Vector{T}
98+
shape::NTuple{D,Int}
99+
weights::W
100+
fftplan
101+
ifftplan
102+
λ::Array{T}
103+
xL1::Array{T,D}
104+
xL2::Array{T,D}
105+
end
106+
107+
LinearOperators.storage_type(op::NFFTToeplitzNormalOp) = typeof(op.Mv5)
108+
109+
function NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1::Array{T,D}, xL2::Array{T,D}) where {T,D}
110+
111+
function produ!(y, shape, fftplan, ifftplan, λ, xL1, xL2, x)
112+
xL1 .= 0
113+
x = reshape(x, shape)
114+
115+
xL1[CartesianIndices(x)] .= x
116+
mul!(xL2, fftplan, xL1)
117+
xL2 .*= λ
118+
mul!(xL1, ifftplan, xL2)
119+
120+
y .= vec(xL1[CartesianIndices(x)])
121+
return y
122+
end
123+
124+
return NFFTToeplitzNormalOp(prod(shape), prod(shape), false, false
125+
, (res,x) -> produ!(res, shape, fftplan, ifftplan, λ, xL1, xL2, x)
126+
, nothing
127+
, nothing
128+
, 0, 0, 0, false, false, false, T[], T[]
129+
, shape, W, fftplan, ifftplan, λ, xL1, xL2)
130+
end
131+
132+
function NFFTToeplitzNormalOp(S::NFFTOp{T}, W=opEye(T,size(S,1))) where {T}
133+
shape = S.plan.N
134+
135+
# plan the FFTs
136+
fftplan = plan_fft( zeros(T, 2 .* shape);flags=FFTW.MEASURE)
137+
ifftplan = plan_ifft(zeros(T, 2 .* shape);flags=FFTW.MEASURE)
138+
139+
# TODO extend the following function by weights
140+
# λ = calculateToeplitzKernel(shape, S.plan.k; m = S.plan.params.m, σ = S.plan.params.σ, window = S.plan.params.window, LUTSize = S.plan.params.LUTSize, fftplan = fftplan)
141+
142+
shape_os = 2 .* shape
143+
p = plan_nfft(typeof(S.plan.k), S.plan.k, shape_os; m = S.plan.params.m, σ = S.plan.params.σ,
144+
precompute=NFFT.POLYNOMIAL, fftflags=FFTW.ESTIMATE, blocking=true)
145+
eigMat = adjoint(p) * ( W * ones(T, size(S.plan.k,2)))
146+
λ = fftplan * fftshift(eigMat)
147+
148+
xL1 = Array{T}(undef, 2 .* shape)
149+
xL2 = similar(xL1)
150+
151+
return NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1, xL2)
152+
end
153+
154+
function LinearOperatorCollection.normalOperator(S::NFFTOpImpl{T}, W=opEye(T,size(S,1))) where T
155+
if S.toeplitz
156+
return NFFTToeplitzNormalOp(S,W)
157+
else
158+
return NormalOp(S,W)
159+
end
160+
end
161+
162+
function Base.copy(A::NFFTToeplitzNormalOp{T,D,W}) where {T,D,W}
163+
fftplan = plan_fft( zeros(T, 2 .* A.shape); flags=FFTW.MEASURE)
164+
ifftplan = plan_ifft(zeros(T, 2 .* A.shape); flags=FFTW.MEASURE)
165+
return NFFTToeplitzNormalOp(A.shape, A.weights, fftplan, ifftplan, A.λ, copy(A.xL1), copy(A.xL2))
166+
end

src/GradientOp.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,36 @@
1-
export GradientOp
1+
function LinearOperatorCollection.constructLinearOperator(::Type{Op};
2+
shape::Tuple, dim::Union{Nothing,Int64}=nothing) where Op <: GradientOp{T} where T <: Number
3+
if dim == nothing
4+
return GradientOpImpl(T, shape)
5+
else
6+
return GradientOpImpl(T, shape, dim)
7+
end
8+
end
9+
210

311
"""
4-
gradOp(T::Type, shape::NTuple{1,Int64})
12+
GradientOpImpl(T::Type, shape::NTuple{1,Int64})
513
614
1d gradient operator for an array of size `shape`
715
"""
8-
GradientOp(T::Type, shape::NTuple{1,Int64}) = GradientOp(T,shape,1)
16+
GradientOpImpl(T::Type, shape::NTuple{1,Int64}) = GradientOpImpl(T,shape,1)
917

1018
"""
11-
gradOp(T::Type, shape::NTuple{2,Int64})
19+
GradientOpImpl(T::Type, shape::NTuple{2,Int64})
1220
1321
2d gradient operator for an array of size `shape`
1422
"""
15-
function GradientOp(T::Type, shape::NTuple{2,Int64})
16-
return vcat( GradientOp(T,shape,1), GradientOp(T,shape,2) )
23+
function GradientOpImpl(T::Type, shape::NTuple{2,Int64})
24+
return vcat( GradientOpImpl(T,shape,1), GradientOpImpl(T,shape,2) )
1725
end
1826

1927
"""
20-
gradOp(T::Type, shape::NTuple{3,Int64})
28+
GradientOpImpl(T::Type, shape::NTuple{3,Int64})
2129
2230
3d gradient operator for an array of size `shape`
2331
"""
24-
function GradientOp(T::Type, shape::NTuple{3,Int64})
25-
return vcat( GradientOp(T,shape,1), GradientOp(T,shape,2), GradientOp(T,shape,3) )
32+
function GradientOpImpl(T::Type, shape::NTuple{3,Int64})
33+
return vcat( GradientOpImpl(T,shape,1), GradientOpImpl(T,shape,2), GradientOpImpl(T,shape,3) )
2634
end
2735

2836
"""
@@ -31,7 +39,7 @@ end
3139
directional gradient operator along the dimension `dim`
3240
for an array of size `shape`
3341
"""
34-
function GradientOp(T::Type, shape::NTuple{N,Int64}, dim::Int64) where N
42+
function GradientOpImpl(T::Type, shape::NTuple{N,Int64}, dim::Int64) where N
3543
nrow = div( (shape[dim]-1)*prod(shape), shape[dim] )
3644
ncol = prod(shape)
3745
return LinearOperator{T}(nrow, ncol, false, false,

src/LinearOperatorCollection.jl

Lines changed: 13 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import LinearAlgebra.BLAS: gemv, gemv!
66
import LinearAlgebra: BlasFloat, normalize!, norm, rmul!, lmul!
77
using SparseArrays
88
using Random
9-
#using CUDA
109

1110
using Reexport
1211
@reexport using Reexport
@@ -15,10 +14,6 @@ using Reexport
1514
LinearOperators.use_prod5!(op::opEye) = false
1615
LinearOperators.has_args5(op::opEye) = false
1716

18-
19-
const Trafo = Union{AbstractMatrix, AbstractLinearOperator, Nothing}
20-
const FuncOrNothing = Union{Function, Nothing}
21-
2217
# Helper function to wrap a prod into a 5-args mul
2318
function wrapProd(prod::Function)
2419
λ = (res, x, α, β) -> begin
@@ -31,68 +26,35 @@ function wrapProd(prod::Function)
3126
return λ
3227
end
3328

34-
include("GradientOp.jl")
35-
include("SamplingOp.jl")
36-
include("WeightingOp.jl")
37-
include("NormalOp.jl")
38-
39-
export linearOperator, linearOperatorList
40-
41-
export constructLinearOperator
42-
export AbstractLinearOperatorFromCollection, WaveletOp, FFTOp, DCTOp, DSTOp
29+
export linearOperatorList, constructLinearOperator
30+
export AbstractLinearOperatorFromCollection, WaveletOp, FFTOp, DCTOp, DSTOp, NFFTOp,
31+
SamplingOp, NormalOp, WeightingOp, GradientOp
4332

4433
abstract type AbstractLinearOperatorFromCollection{T} <: AbstractLinearOperator{T} end
4534
abstract type WaveletOp{T} <: AbstractLinearOperatorFromCollection{T} end
4635
abstract type FFTOp{T} <: AbstractLinearOperatorFromCollection{T} end
4736
abstract type DCTOp{T} <: AbstractLinearOperatorFromCollection{T} end
4837
abstract type DSTOp{T} <: AbstractLinearOperatorFromCollection{T} end
38+
abstract type NFFTOp{T} <: AbstractLinearOperatorFromCollection{T} end
39+
abstract type SamplingOp{T} <: AbstractLinearOperatorFromCollection{T} end
40+
abstract type NormalOp{T} <: AbstractLinearOperatorFromCollection{T} end
41+
abstract type WeightingOp{T} <: AbstractLinearOperatorFromCollection{T} end
42+
abstract type GradientOp{T} <: AbstractLinearOperatorFromCollection{T} end
4943

5044
function constructLinearOperator(::Type{<:AbstractLinearOperatorFromCollection}, args...; kargs...)
5145
error("Operator can't be constructed. You need to load another package!")
5246
end
5347

54-
linearOperator(op::Nothing, shape, T::Type=ComplexF32) = nothing
55-
5648
"""
5749
returns a list of currently implemented `LinearOperator`s
5850
"""
5951
function linearOperatorList()
60-
return ["DCT-II", "DCT-IV", "FFT", "DST", "Wavelet", "Gradient"]
52+
return subtypes(AbstractLinearOperatorFromCollection)
6153
end
6254

63-
"""
64-
linearOperator(op::AbstractString, shape)
65-
66-
returns the `LinearOperator` with name `op`.
67-
68-
# valid names
69-
* `"FFT"`
70-
* `"DCT-II"`
71-
* `"DCT-IV"`
72-
* `"DST"`
73-
* `"Wavelet"`
74-
* `"Gradient"`
75-
"""
76-
function linearOperator(op::AbstractString, shape, T::Type=ComplexF32)
77-
shape_ = tuple(shape...)
78-
if op == "FFT"
79-
trafo = FFTOp(T, shape_, false) #FFTOperator(shape)
80-
elseif op == "DCT-II"
81-
shape_ = tuple(shape[shape .!= 1]...)
82-
trafo = DCTOp(T, shape_, 2)
83-
elseif op == "DCT-IV"
84-
shape_ = tuple(shape[shape .!= 1]...)
85-
trafo = DCTOp(T, shape_, 4)
86-
elseif op == "DST"
87-
trafo = DSTOp(T, shape_)
88-
elseif op == "Wavelet"
89-
trafo = WaveletOp(T,shape_)
90-
elseif op=="Gradient"
91-
trafo = GradientOp(T,shape_)
92-
else
93-
error("Unknown transformation")
94-
end
95-
trafo
96-
end
55+
include("GradientOp.jl")
56+
include("SamplingOp.jl")
57+
include("WeightingOp.jl")
58+
include("NormalOp.jl")
9759

9860
end

0 commit comments

Comments
 (0)