Skip to content

Add parent-type as parameter of normal operator #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/Breakage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
with:
version: 1
arch: x64
- uses: actions/cache@v1
- uses: actions/cache@v4
env:
cache-name: cache-artifacts
with:
Expand Down
12 changes: 9 additions & 3 deletions ext/LinearOperatorNFFTExt/NFFTOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1::matT, xL2::m
, shape, W, fftplan, ifftplan, λ, xL1, xL2)
end

function NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=opEye(eltype(nfft), size(nfft, 1), S= LinearOperators.storage_type(nfft)); kwargs...) where {T}
function NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=nothing; kwargs...) where {T}
shape = nfft.plan.N

tmpVec = similar(nfft.Mv5, (2 .* shape)...)
Expand All @@ -147,7 +147,13 @@ function NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=opEye(eltype(nfft), size(nfft,
precompute=NFFT.POLYNOMIAL, fftflags=FFTW.ESTIMATE, blocking=true)
tmpOnes = similar(tmpVec, size(nfft.plan.k, 2))
tmpOnes .= one(T)
eigMat = adjoint(p) * ( W * tmpOnes)

if !isnothing(W)
eigMat = adjoint(p) * ( W * tmpOnes)
else
eigMat = adjoint(p) * (tmpOnes)
end

λ = fftplan * fftshift(eigMat)

xL1 = tmpVec
Expand All @@ -156,7 +162,7 @@ function NFFTToeplitzNormalOp(nfft::NFFTOp{T}, W=opEye(eltype(nfft), size(nfft,
return NFFTToeplitzNormalOp(shape, W, fftplan, ifftplan, λ, xL1, xL2)
end

function LinearOperatorCollection.normalOperator(S::NFFTOpImpl{T}, W = opEye(eltype(S), size(S, 1), S= LinearOperators.storage_type(S)); copyOpsFn = copy, kwargs...) where T
function LinearOperatorCollection.normalOperator(S::NFFTOpImpl{T}, W = nothing; copyOpsFn = copy, kwargs...) where T
if S.toeplitz
return NFFTToeplitzNormalOp(S,W; kwargs...)
else
Expand Down
13 changes: 9 additions & 4 deletions src/DiagOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,18 @@ function diagNormOpProd!(y, normalOps, idx, x)
return y
end

function LinearOperatorCollection.normalOperator(diag::DiagOp, W=opEye(eltype(diag), size(diag,1), S = LinearOperators.storage_type(diag)); copyOpsFn = copy, kwargs...)
T = promote_type(eltype(diag), eltype(W))
S = promote_type(LinearOperators.storage_type(diag), LinearOperators.storage_type(W))
function LinearOperatorCollection.normalOperator(diag::DiagOp, W=nothing; copyOpsFn = copy, kwargs...)
if !isnothing(W)
T = promote_type(eltype(diag), eltype(W))
S = promote_type(LinearOperators.storage_type(diag), LinearOperators.storage_type(W))
else
T = eltype(diag)
S = LinearOperators.storage_type(diag)
end
isconcretetype(S) || throw(LinearOperatorException("Storage types cannot be promoted to a concrete type"))
tmp = S(undef, diag.nrow)
tmp .= one(eltype(diag))
weights = W*tmp
weights = isnothing(W) ? tmp : W * tmp


if diag.equalOps
Expand Down
2 changes: 1 addition & 1 deletion src/LinearOperatorCollection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ abstract type DCTOp{T} <: AbstractLinearOperatorFromCollection{T} end
abstract type DSTOp{T} <: AbstractLinearOperatorFromCollection{T} end
abstract type NFFTOp{T} <: AbstractLinearOperatorFromCollection{T} end
abstract type SamplingOp{T} <: AbstractLinearOperatorFromCollection{T} end
abstract type NormalOp{T} <: AbstractLinearOperatorFromCollection{T} end
abstract type NormalOp{T,S} <: AbstractLinearOperatorFromCollection{T} end
abstract type GradientOp{T} <: AbstractLinearOperatorFromCollection{T} end
abstract type RadonOp{T} <: AbstractLinearOperatorFromCollection{T} end

Expand Down
22 changes: 14 additions & 8 deletions src/NormalOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,15 @@ Computes `adjoint(parent) * weights * parent`.
* `weights` - Optional weights for normal operator. Must already be of form `weights = adjoint.(w) .* w`

"""
function LinearOperatorCollection.NormalOp(::Type{T}; parent, weights = opEye(eltype(parent), size(parent, 1), S = storage_type(parent))) where T <: Number
function LinearOperatorCollection.NormalOp(::Type{T}; parent, weights = nothing) where T <: Number
return NormalOp(T, parent, weights)
end

function NormalOp(::Type{T}, parent, ::Nothing) where T
weights = opEye(eltype(parent), size(parent, 1), S = storage_type(parent))
return NormalOp(T, parent, weights)
end
NormalOp(::Union{Type{T}, Type{Complex{T}}}, parent, weights::AbstractVector{T}) where T = NormalOp(T, parent, WeightingOp(weights))

NormalOp(::Union{Type{T}, Type{Complex{T}}}, parent, weights::AbstractLinearOperator{T}; kwargs...) where T = NormalOpImpl(parent, weights)
NormalOp(::Union{Type{T}, Type{Complex{T}}}, parent, weights; kwargs...) where T = NormalOpImpl(parent, weights)

mutable struct NormalOpImpl{T,S,D,V} <: NormalOp{T}
mutable struct NormalOpImpl{T,S,D,V} <: NormalOp{T, S}
nrow :: Int
ncol :: Int
symmetric :: Bool
Expand Down Expand Up @@ -56,13 +52,23 @@ function NormalOpImpl(parent, weights)
tmp = S(undef, size(parent, 1))
return NormalOpImpl(parent, weights, tmp)
end
function NormalOpImpl(parent, weights::Nothing)
S = storage_type(parent)
tmp = S(undef, size(parent, 1))
return NormalOpImpl(parent, weights, tmp)
end

function NormalOpImpl(parent, weights, tmp)
function produ!(y, parent, weights, tmp, x)
mul!(tmp, parent, x)
mul!(tmp, weights, tmp) # This can be dangerous. We might need to create two tmp vectors
return mul!(y, adjoint(parent), tmp)
end
function produ!(y, parent, weights::Nothing, tmp, x)
mul!(tmp, parent, x)
return mul!(y, adjoint(parent), tmp)
end


return NormalOpImpl{eltype(parent), typeof(parent), typeof(weights), typeof(tmp)}(size(parent,2), size(parent,2), false, false
, (res,x) -> produ!(res, parent, weights, tmp, x)
Expand All @@ -81,6 +87,6 @@ end

Constructs a normal operator of the parent in an opinionated way, i.e. it tries to apply optimisations to the resulting operator.
"""
function normalOperator(parent, weights=opEye(eltype(parent), size(parent, 1), S= storage_type(parent)); kwargs...)
function normalOperator(parent, weights=nothing; kwargs...)
return NormalOp(eltype(storage_type((parent))); parent = parent, weights = weights)
end
2 changes: 1 addition & 1 deletion src/ProdOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ end
Fuses weights of `ẀeightingOp` by computing `adjoint.(weights) .* weights`
"""
normalOperator(S::ProdOp{T, <:WeightingOp, matT}; kwargs...) where {T, matT} = normalOperator(S.B, WeightingOp(adjoint.(S.A.weights) .* S.A.weights); kwargs...)
function normalOperator(S::ProdOp, W=opEye(eltype(S),size(S,1), S = storage_type(S)); kwargs...)
function normalOperator(S::ProdOp, W=nothing; kwargs...)
arrayType = storage_type(S)
tmp = arrayType(undef, size(S.A, 2))
return ProdNormalOp(S.B, normalOperator(S.A, W; kwargs...), tmp)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using RadonKA
using JLArrays

areTypesDefined = @isdefined arrayTypes
arrayTypes = areTypesDefined ? arrayTypes : [Array, JLArray]
arrayTypes = areTypesDefined ? arrayTypes : [Array] #, JLArray]

@testset "LinearOperatorCollection" begin
include("testNormalOp.jl")
Expand Down
2 changes: 1 addition & 1 deletion test/testOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ function testDiagOp(N=32,K=2;arrayType = Array)

@testset "Weighted Diag Normal" begin
w = rand(eltype(op1), size(op1, 1))
wop = WeightingOp(w)
wop = WeightingOp(arrayType(w))
prod1 = ProdOp(wop, op1)
prod2 = ProdOp(wop, op2)
prod3 = ProdOp(wop, op3)
Expand Down
Loading