Skip to content

Commit 41fc693

Browse files
committed
finalize design
1 parent 062e03f commit 41fc693

File tree

6 files changed

+53
-35
lines changed

6 files changed

+53
-35
lines changed

ext/LinearOperatorFFTWExt/DCTOp.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
export DCTOp
1+
export DCTOpImpl
22

3-
mutable struct DCTOp{T} <: AbstractLinearOperator{T}
3+
function LinearOperatorCollection.constructLinearOperator(::Type{Op};
4+
shape::Tuple, dcttype::Int) where Op <: DCTOp{T} where T <: Number
5+
return DCTOpImpl(T, shape, dcttype)
6+
end
7+
8+
mutable struct DCTOpImpl{T} <: AbstractLinearOperatorFromCollection{T}
49
nrow :: Int
510
ncol :: Int
611
symmetric :: Bool
@@ -20,19 +25,19 @@ mutable struct DCTOp{T} <: AbstractLinearOperator{T}
2025
dcttype::Int
2126
end
2227

23-
LinearOperators.storage_type(op::DCTOp) = typeof(op.Mv5)
28+
LinearOperators.storage_type(op::DCTOpImpl) = typeof(op.Mv5)
2429

2530
"""
26-
DCTOp(T::Type, shape::Tuple, dcttype=2)
31+
DCTOpImpl(T::Type, shape::Tuple, dcttype=2)
2732
28-
returns a `DCTOp <: AbstractLinearOperator` which performs a DCT on a given input array.
33+
returns a `DCTOpImpl <: AbstractLinearOperator` which performs a DCT on a given input array.
2934
3035
# Arguments:
3136
* `T::Type` - type of the array to transform
3237
* `shape::Tuple` - size of the array to transform
3338
* `dcttype` - type of DCT (currently `2` and `4` are supported)
3439
"""
35-
function DCTOp(T::Type, shape::Tuple, dcttype=2)
40+
function DCTOpImpl(T::Type, shape::Tuple, dcttype=2)
3641

3742
tmp=Array{Complex{real(T)}}(undef, shape)
3843
if dcttype == 2
@@ -50,7 +55,7 @@ function DCTOp(T::Type, shape::Tuple, dcttype=2)
5055
error("DCT type $(dcttype) not supported")
5156
end
5257

53-
return DCTOp{T}(prod(shape), prod(shape), false, false,
58+
return DCTOpImpl{T}(prod(shape), prod(shape), false, false,
5459
prod!, nothing, tprod!,
5560
0, 0, 0, true, false, true, T[], T[],
5661
plan, dcttype)
@@ -68,6 +73,6 @@ function dct_multiply4(res::Vector{T}, plan::P, x::Vector{T}, tmp::Array{T,D}, f
6873
res .= factor.*vec(tmp)
6974
end
7075

71-
function Base.copy(S::DCTOp)
72-
return DCTOp(eltype(S), size(S.plan), S.dcttype)
76+
function Base.copy(S::DCTOpImpl)
77+
return DCTOpImpl(eltype(S), size(S.plan), S.dcttype)
7378
end

ext/LinearOperatorFFTWExt/DSTOp.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
export DSTOp
1+
export DSTOpImpl
22

3-
mutable struct DSTOp{T} <: AbstractLinearOperator{T}
3+
function LinearOperatorCollection.constructLinearOperator(::Type{Op};
4+
shape::Tuple, shift::Bool=true, unitary::Bool=true, cuda::Bool=false) where Op <: DSTOp{T} where T <: Number
5+
return DSTOpImpl(T, shape)
6+
end
7+
8+
mutable struct DSTOpImpl{T} <: AbstractLinearOperatorFromCollection{T}
49
nrow :: Int
510
ncol :: Int
611
symmetric :: Bool
@@ -20,26 +25,26 @@ mutable struct DSTOp{T} <: AbstractLinearOperator{T}
2025
iplan
2126
end
2227

23-
LinearOperators.storage_type(op::DSTOp) = typeof(op.Mv5)
28+
LinearOperators.storage_type(op::DSTOpImpl) = typeof(op.Mv5)
2429

2530
"""
26-
DSTOp(T::Type, shape::Tuple)
31+
DSTOpImpl(T::Type, shape::Tuple)
2732
2833
returns a `LinearOperator` which performs a DST on a given input array.
2934
3035
# Arguments:
3136
* `T::Type` - type of the array to transform
3237
* `shape::Tuple` - size of the array to transform
3338
"""
34-
function DSTOp(T::Type, shape::Tuple)
39+
function DSTOpImpl(T::Type, shape::Tuple)
3540
tmp=Array{Complex{real(T)}}(undef, shape)
3641

3742
plan = FFTW.plan_r2r!(tmp,FFTW.RODFT10)
3843
iplan = FFTW.plan_r2r!(tmp,FFTW.RODFT01)
3944

4045
w = weights(shape, T)
4146

42-
return DSTOp{T}(prod(shape), prod(shape), true, false
47+
return DSTOpImpl{T}(prod(shape), prod(shape), true, false
4348
, (res,x) -> dst_multiply!(res,plan,x,tmp,w)
4449
, nothing
4550
, (res,x) -> dst_bmultiply!(res,iplan,x,tmp,w)
@@ -72,6 +77,6 @@ function dst_bmultiply!(res::Vector{T}, plan::P, x::Vector{T}, tmp::Array{T,D},
7277
res[:] .= vec(tmp)./(8*length(tmp))
7378
end
7479

75-
function Base.copy(S::DSTOp)
76-
return DSTOp(eltype(S), size(S.plan))
80+
function Base.copy(S::DSTOpImpl)
81+
return DSTOpImpl(eltype(S), size(S.plan))
7782
end

ext/LinearOperatorFFTWExt/FFTOp.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
export FFTOp
1+
export FFTOpImpl
22
import Base.copy
33

4-
mutable struct FFTOp{T} <: AbstractLinearOperator{T}
4+
function LinearOperatorCollection.constructLinearOperator(::Type{Op};
5+
shape::Tuple, shift::Bool=true, unitary::Bool=true, cuda::Bool=false) where Op <: FFTOp{T} where T <: Number
6+
return FFTOpImpl(T, shape, shift; unitary, cuda)
7+
end
8+
9+
mutable struct FFTOpImpl{T} <: AbstractLinearOperatorFromCollection{T}
510
nrow :: Int
611
ncol :: Int
712
symmetric :: Bool
@@ -23,10 +28,10 @@ mutable struct FFTOp{T} <: AbstractLinearOperator{T}
2328
unitary::Bool
2429
end
2530

26-
LinearOperators.storage_type(op::FFTOp) = typeof(op.Mv5)
31+
LinearOperators.storage_type(op::FFTOpImpl) = typeof(op.Mv5)
2732

2833
"""
29-
FFTOp(T::Type, shape::Tuple, shift=true, unitary=true)
34+
FFTOpImpl(T::Type, shape::Tuple, shift=true, unitary=true)
3035
3136
returns an operator which performs an FFT on Arrays of type T
3237
@@ -36,7 +41,7 @@ returns an operator which performs an FFT on Arrays of type T
3641
* (`shift=true`) - if true, fftshifts are performed
3742
* (`unitary=true`) - if true, FFT is normalized such that it is unitary
3843
"""
39-
function FFTOp(T::Type, shape::NTuple{D,Int64}, shift::Bool=true; unitary::Bool=true, cuda::Bool=false) where D
44+
function FFTOpImpl(T::Type, shape::NTuple{D,Int64}, shift::Bool=true; unitary::Bool=true, cuda::Bool=false) where D
4045

4146
#tmpVec = cuda ? CuArray{T}(undef,shape) : Array{Complex{real(T)}}(undef, shape)
4247
tmpVec = Array{Complex{real(T)}}(undef, shape)
@@ -54,7 +59,7 @@ function FFTOp(T::Type, shape::NTuple{D,Int64}, shift::Bool=true; unitary::Bool=
5459
let shape_=shape, plan_=plan, iplan_=iplan, tmpVec_=tmpVec, facF_=facF, facB_=facB
5560

5661
if shift
57-
return FFTOp{T}(prod(shape), prod(shape), false, false
62+
return FFTOpImpl{T}(prod(shape), prod(shape), false, false
5863
, (res, x) -> fft_multiply_shift!(res, plan_, x, shape_, facF_, tmpVec_)
5964
, nothing
6065
, (res, x) -> fft_multiply_shift!(res, iplan_, x, shape_, facB_, tmpVec_)
@@ -64,7 +69,7 @@ function FFTOp(T::Type, shape::NTuple{D,Int64}, shift::Bool=true; unitary::Bool=
6469
, shift
6570
, unitary)
6671
else
67-
return FFTOp{T}(prod(shape), prod(shape), false, false
72+
return FFTOpImpl{T}(prod(shape), prod(shape), false, false
6873
, (res, x) -> fft_multiply!(res, plan_, x, facF_, tmpVec_)
6974
, nothing
7075
, (res, x) -> fft_multiply!(res, iplan_, x, facB_, tmpVec_)
@@ -91,6 +96,6 @@ function fft_multiply_shift!(res::AbstractVector{T}, plan::P, x::AbstractVector{
9196
end
9297

9398

94-
function Base.copy(S::FFTOp)
95-
return FFTOp(eltype(S), size(S.plan), S.shift, unitary=S.unitary)
99+
function Base.copy(S::FFTOpImpl)
100+
return FFTOpImpl(eltype(S), size(S.plan), S.shift, unitary=S.unitary)
96101
end

ext/LinearOperatorFFTWExt/LinearOperatorFFTWExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module LinearOperatorFFTWExt
22

3+
using LinearOperatorCollection, FFTW
4+
35
include("FFTOp.jl")
46
include("DCTOp.jl")
57
include("DSTOp.jl")

ext/LinearOperatorWaveletExt/WaveletOp.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
export WaveletOp
1+
export WaveletOpImpl
22

3-
import LinearOperatorCollection: constructLinearOperator
3+
using LinearOperatorCollection, Wavelets
44

5-
function constructLinearOperator(::Type{Op}; eltype::Type, shape::Tuple, wt=wavelet(WT.db2)) where Op <: WaveletOp{T} where T <: Number
6-
return WaveletOpImpl(eltype, shape, wt)
5+
function LinearOperatorCollection.constructLinearOperator(::Type{Op};
6+
shape::Tuple, wt=wavelet(WT.db2)) where Op <: WaveletOp{T} where T <: Number
7+
return WaveletOpImpl(T, shape, wt)
78
end
89

910
"""

test/testOperators.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ function testDCT1d(N=32)
44
for i=1:5
55
x .+= rand()*cos.(rand(1:N^2)*collect(1:N^2)) .+ 1im*rand()*cos.(rand(1:N^2)*collect(1:N^2))
66
end
7-
D1 = DCTOp(ComplexF64,(N^2,),2)
7+
D1 = constructLinearOperator(DCTOp{ComplexF64}, shape=(N^2,), dcttype=2)
88
D2 = sqrt(2/N^2)*[cos(pi/(N^2)*j*(k+0.5)) for j=0:N^2-1,k=0:N^2-1]
99
D2[1,:] .*= 1/sqrt(2)
10-
D3 = DCTOp(ComplexF64,(N^2,),4)
10+
D3 = constructLinearOperator(DCTOp{ComplexF64}, shape=(N^2,), dcttype=4)
1111
D4 = sqrt(2/N^2)*[cos(pi/(N^2)*(j+0.5)*(k+0.5)) for j=0:N^2-1,k=0:N^2-1]
1212

1313
y1 = D1*x
@@ -32,7 +32,7 @@ function testFFT1d(N=32,shift=true)
3232
for i=1:5
3333
x .+= rand()*cos.(rand(1:N^2)*collect(1:N^2))
3434
end
35-
D1 = FFTOp(ComplexF64,(N^2,),shift)
35+
D1 = constructLinearOperator(FFTOp{ComplexF64}, shape=(N^2,), shift=shift)
3636
D2 = 1.0/N*[exp(-2*pi*im*j*k/N^2) for j=0:N^2-1,k=0:N^2-1]
3737

3838
y1 = D1*x
@@ -58,7 +58,7 @@ function testFFT2d(N=32,shift=true)
5858
for i=1:5
5959
x .+= rand()*cos.(rand(1:N^2)*collect(1:N^2))
6060
end
61-
D1 = FFTOp(ComplexF64,(N,N),shift)
61+
D1 = constructLinearOperator(FFTOp{ComplexF64}, shape=(N,N), shift=shift)
6262

6363
idx = CartesianIndices((N,N))[collect(1:N^2)]
6464
D2 = 1.0/N*[ exp(-2*pi*im*((idx[j][1]-1)*(idx[k][1]-1)+(idx[j][2]-1)*(idx[k][2]-1))/N) for j=1:N^2, k=1:N^2 ]
@@ -156,7 +156,7 @@ end
156156

157157
function testWavelet(M=64,N=60)
158158
x = rand(M,N)
159-
WOp = WaveletOp(Float64,(M,N))
159+
WOp = constructLinearOperator(WaveletOp{Float64}, shape=(M,N))
160160
x_wavelet = WOp*vec(x)
161161
x_reco = reshape( adjoint(WOp)*x_wavelet, M, N)
162162

0 commit comments

Comments
 (0)