diff --git a/src/ADMM.jl b/src/ADMM.jl index 4bffb0bc..6109990a 100644 --- a/src/ADMM.jl +++ b/src/ADMM.jl @@ -1,8 +1,9 @@ export ADMM -mutable struct ADMM{matT,opT,R,ropT,P,vecT,rvecT,preconT,rT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}} +mutable struct ADMM{matT,N,opT,R,ropT,P,vecT,rvecT,preconT,rT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}} # operators and regularization A::matT + shape::NTuple{N, Int64} reg::Vector{R} regTrafo::Vector{ropT} proj::Vector{P} @@ -81,6 +82,7 @@ function ADMM(A , relTol::Real = eps(real(eltype(AHA))) , tolInner::Real = 1e-5 , verbose = false + , shape = (size(AHA, 2),) ) T = eltype(AHA) @@ -135,7 +137,7 @@ function ADMM(A # normalization parameters reg = normalize(ADMM, normalizeReg, reg, A, nothing) - return ADMM(A,reg,regTrafo,proj,AHA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,rho,iterations + return ADMM(A,shape,reg,regTrafo,proj,AHA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,rho,iterations ,iterationsCG,cgStateVars,rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,rT(0),Δ,rT(absTol),rT(relTol),rT(tolInner),normalizeReg,vary_rho,verbose) end @@ -198,7 +200,7 @@ function iterate(solver::ADMM, iteration=1) cg!(solver.x, AHA, solver.β, Pl = solver.precon, maxiter = solver.iterationsCG, reltol = solver.tolInner, statevars = solver.cgStateVars, verbose = solver.verbose) for proj in solver.proj - prox!(proj, solver.x) + prox!(proj, reshape(solver.x, solver.shape)) end # proximal map for regularization terms @@ -212,7 +214,7 @@ function iterate(solver::ADMM, iteration=1) mul!(solver.z[i], solver.regTrafo[i], solver.x) solver.z[i] .+= solver.u[i] if solver.ρ[i] != 0 - prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/2solver.ρ[i]) # λ is divided by 2 to match the ISTA-type algorithms + prox!(solver.reg[i], reshape(solver.z[i], solver.shape), λ(solver.reg[i])/2solver.ρ[i]) # λ is divided by 2 to match the ISTA-type algorithms end # 3. update u diff --git a/src/CGNR.jl b/src/CGNR.jl index 14866b55..ca11da22 100644 --- a/src/CGNR.jl +++ b/src/CGNR.jl @@ -1,8 +1,9 @@ export cgnr, CGNR -mutable struct CGNR{matT,opT,vecT,T,R,PR} <: AbstractKrylovSolver +mutable struct CGNR{matT,opT, N,vecT,T,R,PR} <: AbstractKrylovSolver A::matT AHA::opT + shape::NTuple{N, Int64} L2::R constr::PR x::vecT @@ -49,6 +50,7 @@ function CGNR(A , weights::AbstractVector = similar(AHA, 0) , iterations::Int = 10 , relTol::Real = eps(real(eltype(AHA))) + , shape = (size(AHA, 2),) ) T = eltype(AHA) @@ -82,7 +84,7 @@ function CGNR(A other = identity.(other) - return CGNR(A, AHA, + return CGNR(A, AHA, shape, L2, other, x, x₀, pl, vl, αl, βl, ζl, weights, iterations, relTol, 0.0, normalizeReg) end @@ -134,7 +136,7 @@ performs one CGNR iteration. function iterate(solver::CGNR, iteration::Int=0) if done(solver, iteration) for r in solver.constr - prox!(r, solver.x) + prox!(r, reshape(solver.x, solver.shape)) end return nothing end diff --git a/src/DAXConstrained.jl b/src/DAXConstrained.jl index 8d9f062c..ca8b4fcf 100644 --- a/src/DAXConstrained.jl +++ b/src/DAXConstrained.jl @@ -1,7 +1,8 @@ export DaxConstrained -mutable struct DaxConstrained{matT,T,Tsparse,U} <: AbstractRowActionSolver +mutable struct DaxConstrained{matT,N,T,Tsparse,U} <: AbstractRowActionSolver A::matT + shape::NTuple{N, Int64} u::Vector{T} λ::Float64 B::Tsparse @@ -49,6 +50,7 @@ function DaxConstrained(A , sparseTrafo=nothing , iterations::Int=3 , iterationsInner::Int=2 + , shape = (size(A, 1),) ) T = eltype(A) @@ -79,7 +81,7 @@ function DaxConstrained(A τl = zero(T) αl = zero(T) - return DaxConstrained(A,u,Float64(λ),B,Bnorm²,denom,rowindex,x,bk,bc,xl,yl,yc,δc,εw,τl,αl + return DaxConstrained(A,shape,u,Float64(λ),B,Bnorm²,denom,rowindex,x,bk,bc,xl,yl,yc,δc,εw,τl,αl ,rT.(weights),iterations,iterationsInner) end diff --git a/src/DAXKaczmarz.jl b/src/DAXKaczmarz.jl index 2d080e43..c2f9137a 100644 --- a/src/DAXKaczmarz.jl +++ b/src/DAXKaczmarz.jl @@ -1,7 +1,8 @@ export DaxKaczmarz -mutable struct DaxKaczmarz{matT,T,U} <: AbstractRowActionSolver +mutable struct DaxKaczmarz{matT,N,T,U} <: AbstractRowActionSolver A::matT + shape::NTuple{N, Int64} u::Vector{T} reg::Vector{<:AbstractRegularization} λ::Float64 @@ -51,6 +52,7 @@ function DaxKaczmarz(A , enforcePositive::Bool=false , iterations::Int=3 , iterationsInner::Int=2 + , shape = (size(A, 1),) ) # setup denom and rowindex @@ -80,7 +82,7 @@ function DaxKaczmarz(A if !isempty(reg) && !isnothing(sparseTrafo) reg = map(r -> TransformedRegularization(r, sparseTrafo), reg) end - return DaxKaczmarz(A,u,reg, Float64(λ), denom,rowindex,sumrowweights,x,bk,xl,yl,εw,τl,αl + return DaxKaczmarz(A,shape,u,reg, Float64(λ), denom,rowindex,sumrowweights,x,bk,xl,yl,εw,τl,αl ,T.(weights) ,iterations,iterationsInner) end @@ -103,7 +105,7 @@ end function iterate(solver::DaxKaczmarz, iteration::Int=0) if done(solver,iteration) for r in solver.reg - prox!(r, solver.x) + prox!(r, reshape(solver.x, solver.shape)) end return nothing end diff --git a/src/FISTA.jl b/src/FISTA.jl index f2648c8a..d3f6969f 100644 --- a/src/FISTA.jl +++ b/src/FISTA.jl @@ -1,8 +1,9 @@ export FISTA -mutable struct FISTA{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVector{Complex{rT}}}, matA, matAHA, R, RN} <: AbstractProximalGradientSolver +mutable struct FISTA{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVector{Complex{rT}}}, matA, N, matAHA, R, RN} <: AbstractProximalGradientSolver A::matA AHA::matAHA + shape::NTuple{N, Int64} reg::R proj::Vector{RN} x::vecT @@ -60,6 +61,7 @@ function FISTA(A , iterations = 50 , restart = :none , verbose = false + , shape = (size(AHA, 2),) ) T = eltype(AHA) @@ -87,7 +89,7 @@ function FISTA(A reg = normalize(FISTA, normalizeReg, reg, A, nothing) - return FISTA(A, AHA, reg[1], other, x, x₀, xᵒˡᵈ, res, rT(rho),rT(theta),rT(theta),iterations,rT(relTol),normalizeReg,one(rT),rT(Inf),verbose,restart) + return FISTA(A, AHA, shape, reg[1], other, x, x₀, xᵒˡᵈ, res, rT(rho),rT(theta),rT(theta),iterations,rT(relTol),normalizeReg,one(rT),rT(Inf),verbose,restart) end """ @@ -146,10 +148,10 @@ function iterate(solver::FISTA, iteration::Int=0) # solver.x .+= solver.ρ .* solver.x₀ # proximal map - prox!(solver.reg, solver.x, solver.ρ * λ(solver.reg)) + prox!(solver.reg, reshape(solver.x, solver.shape), solver.ρ * λ(solver.reg)) for proj in solver.proj - prox!(proj, solver.x) + prox!(proj, reshape(solver.x, solver.shape)) end # gradient restart conditions diff --git a/src/Kaczmarz.jl b/src/Kaczmarz.jl index c4265b62..4e886987 100644 --- a/src/Kaczmarz.jl +++ b/src/Kaczmarz.jl @@ -1,8 +1,9 @@ export kaczmarz export Kaczmarz -mutable struct Kaczmarz{matT,T,U,R,RN} <: AbstractRowActionSolver +mutable struct Kaczmarz{matT,N,T,U,R,RN} <: AbstractRowActionSolver A::matT + shape::NTuple{N, Int64} u::Vector{T} L2::R reg::Vector{RN} @@ -55,6 +56,7 @@ function Kaczmarz(A , seed::Int = 1234 , iterations::Int = 10 , regMatrix = nothing + , shape = (size(A, 1),) ) T = real(eltype(A)) @@ -105,7 +107,7 @@ function Kaczmarz(A τl = zero(eltype(A)) αl = zero(eltype(A)) - return Kaczmarz(A, u, L2, other, denom, rowindex, rowIndexCycle, x, vl, εw, τl, αl, + return Kaczmarz(A, shape, u, L2, other, denom, rowindex, rowIndexCycle, x, vl, εw, τl, αl, T.(w), randomized, subMatrixSize, probabilities, shuffleRows, Int64(seed), iterations, regMatrix, normalizeReg) @@ -167,7 +169,7 @@ function iterate(solver::Kaczmarz, iteration::Int=0) end for r in solver.reg - prox!(r, solver.x) + prox!(r, reshape(solver.x, solver.shape)) end return solver.vl, iteration+1 diff --git a/src/OptISTA.jl b/src/OptISTA.jl index bc916ced..732c81d6 100644 --- a/src/OptISTA.jl +++ b/src/OptISTA.jl @@ -1,8 +1,9 @@ export optista, OptISTA -mutable struct OptISTA{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVector{Complex{rT}}}, matA, matAHA, R, RN} <: AbstractProximalGradientSolver +mutable struct OptISTA{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVector{Complex{rT}}}, N, matA, matAHA, R, RN} <: AbstractProximalGradientSolver A::matA AHA::matAHA + shape::NTuple{N, Int64} reg::R proj::Vector{RN} x::vecT @@ -65,6 +66,7 @@ function OptISTA(A , relTol = eps(real(eltype(AHA))) , iterations = 50 , verbose = false + , shape = (size(AHA, 2),) ) T = eltype(AHA) @@ -98,7 +100,7 @@ function OptISTA(A other = identity.(other) reg = normalize(OptISTA, normalizeReg, reg, A, nothing) - return OptISTA(A, AHA, reg[1], other, x, x₀, y, z, zᵒˡᵈ, res, rT(rho),rT(theta),rT(theta),rT(θn),rT(0),rT(1),rT(1), + return OptISTA(A, AHA, shape, reg[1], other, x, x₀, y, z, zᵒˡᵈ, res, rT(rho),rT(theta),rT(theta),rT(θn),rT(0),rT(1),rT(1), iterations,rT(relTol),normalizeReg,one(rT),rT(Inf),verbose) end @@ -169,7 +171,7 @@ function iterate(solver::OptISTA, iteration::Int=0) solver.verbose && println("Iteration $iteration; rel. residual = $(solver.rel_res_norm)") # proximal map - prox!(solver.reg, solver.y, solver.ρ * solver.γ * λ(solver.reg)) + prox!(solver.reg, reshape(solver.y, solver.shape), solver.ρ * solver.γ * λ(solver.reg)) # inertia steps # z = x + (y - yᵒˡᵈ) / γ diff --git a/src/POGM.jl b/src/POGM.jl index 86c44d48..a60052a5 100644 --- a/src/POGM.jl +++ b/src/POGM.jl @@ -1,8 +1,9 @@ export pogm, POGM -mutable struct POGM{rT<:Real,vecT<:Union{AbstractVector{rT},AbstractVector{Complex{rT}}},matA,matAHA,R,RN} <: AbstractProximalGradientSolver +mutable struct POGM{rT<:Real,vecT<:Union{AbstractVector{rT},AbstractVector{Complex{rT}}},matA,matAHA,N,R,RN} <: AbstractProximalGradientSolver A::matA AHA::matAHA + shape::NTuple{N, Int64} reg::R proj::Vector{RN} x::vecT @@ -81,6 +82,7 @@ function POGM(A , iterations = 50 , restart = :none , verbose = false + , shape = (size(AHA, 2),) ) T = eltype(AHA) @@ -109,7 +111,7 @@ function POGM(A other = identity.(other) reg = normalize(POGM, normalizeReg, reg, A, nothing) - return POGM(A, AHA, reg[1], other, x, x₀, xᵒˡᵈ, y, z, w, res, rT(rho), rT(theta), rT(theta), rT(0), rT(1), rT(1), rT(1), rT(1), rT(sigma_fac), + return POGM(A, AHA, shape, reg[1], other, x, x₀, xᵒˡᵈ, y, z, w, res, rT(rho), rT(theta), rT(theta), rT(0), rT(1), rT(1), rT(1), rT(1), rT(sigma_fac), iterations, rT(relTol), normalizeReg, one(rT), rT(Inf), verbose, restart) end @@ -192,9 +194,9 @@ function iterate(solver::POGM, iteration::Int=0) solver.z .= solver.x #store this for next iteration and GR # proximal map - prox!(solver.reg, solver.x, solver.γ * λ(solver.reg)) + prox!(solver.reg, reshape(solver.x, solver.shape), solver.γ * λ(solver.reg)) for proj in solver.proj - prox!(proj, solver.x) + prox!(proj, reshape(solver.x, solver.shape)) end # gradient restart conditions diff --git a/src/Regularization/MaskedRegularization.jl b/src/Regularization/MaskedRegularization.jl index dc68d735..0734fd7a 100644 --- a/src/Regularization/MaskedRegularization.jl +++ b/src/Regularization/MaskedRegularization.jl @@ -21,7 +21,7 @@ julia> prox!(masked, fill(-1, 4)) """ struct MaskedRegularization{S, R<:AbstractRegularization} <: AbstractNestedRegularization{S} reg::R - mask::Vector{Bool} + mask::AbstractArray{Bool} MaskedRegularization(reg::R, mask) where R <: AbstractRegularization = new{R, R}(reg, mask) MaskedRegularization(reg::R, mask) where {S, R<:AbstractNestedRegularization{S}} = new{S,R}(reg, mask) end @@ -29,12 +29,12 @@ innerreg(reg::MaskedRegularization) = reg.reg function prox!(reg::MaskedRegularization, x::AbstractArray, args...) - z = view(x, findall(reg.mask)) + z = view(x, reg.mask) prox!(reg.reg, z, args...) return x end function norm(reg::MaskedRegularization, x::AbstractArray, args...) - z = view(x, findall(reg.mask)) + z = view(x, reg.mask) result = norm(reg.reg, z, args...) return result end \ No newline at end of file diff --git a/src/Regularization/NestedRegularization.jl b/src/Regularization/NestedRegularization.jl index 4fe37447..e671e298 100644 --- a/src/Regularization/NestedRegularization.jl +++ b/src/Regularization/NestedRegularization.jl @@ -26,5 +26,5 @@ sinktype(::AbstractNestedRegularization{S}) where S = S prox!(reg::AbstractNestedRegularization{S}, x) where S <: AbstractParameterizedRegularization = prox!(reg, x, λ(reg)) norm(reg::AbstractNestedRegularization{S}, x) where S <: AbstractParameterizedRegularization = norm(reg, x, λ(reg)) -prox!(reg::AbstractNestedRegularization, x, args...) = prox!(innerreg(reg), x, args...) -norm(reg::AbstractNestedRegularization, x, args...) = norm(innerreg(reg), x, args...) \ No newline at end of file +#prox!(reg::AbstractNestedRegularization, x, args...) = prox!(innerreg(reg), x, args...) +#norm(reg::AbstractNestedRegularization, x, args...) = norm(innerreg(reg), x, args...) \ No newline at end of file diff --git a/src/Regularization/PlugAndPlayRegularization.jl b/src/Regularization/PlugAndPlayRegularization.jl index 0c16dd90..0e86143f 100644 --- a/src/Regularization/PlugAndPlayRegularization.jl +++ b/src/Regularization/PlugAndPlayRegularization.jl @@ -11,18 +11,16 @@ The actual regularization term is indirectly defined by the learned proximal map # Keywords * `model` - model applied to the image -* `shape` - dimensions of the image * `input_transform` - transform of image before `model` """ struct PlugAndPlayRegularization{T, M, I} <: AbstractParameterizedRegularization{T} model::M λ::T - shape::Vector{Int} input_transform::I ignoreIm::Bool - PlugAndPlayRegularization(λ::T; model::M, shape, input_transform::I=RegularizedLeastSquares.MinMaxTransform, ignoreIm = false, kargs...) where {T, M, I} = new{T, M, I}(model, λ, shape, input_transform, ignoreIm) + PlugAndPlayRegularization(λ::T; model::M, input_transform::I=RegularizedLeastSquares.MinMaxTransform, ignoreIm = false, kargs...) where {T<:Number, M, I} = new{T, M, I}(model, λ, input_transform, ignoreIm) end -PlugAndPlayRegularization(model, shape; kwargs...) = PlugAndPlayRegularization(one(Float32); kwargs..., model = model, shape = shape) +PlugAndPlayRegularization(model; kwargs...) = PlugAndPlayRegularization(one(Float32); kwargs..., model = model) function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Complex{T}} out = real.(x) @@ -43,8 +41,6 @@ function prox!(self::PlugAndPlayRegularization, x::AbstractArray{T}, λ::T) wher end out = copy(x) - out = reshape(out, self.shape...) - tf = self.input_transform(out) out = RegularizedLeastSquares.transform(tf, out) diff --git a/src/Regularization/TransformedRegularization.jl b/src/Regularization/TransformedRegularization.jl index 8bd98d20..8f885bd5 100644 --- a/src/Regularization/TransformedRegularization.jl +++ b/src/Regularization/TransformedRegularization.jl @@ -26,12 +26,25 @@ end innerreg(reg::TransformedRegularization) = reg.reg function prox!(reg::TransformedRegularization, x::AbstractArray, args...) + shape = size(x) + z = reg.trafo * vec(x) + result = prox!(reg.reg, reshape(z, shape), args...) + x[:] = adjoint(reg.trafo) * result + return x +end +function prox!(reg::TransformedRegularization, x::AbstractVector, args...) z = reg.trafo * x result = prox!(reg.reg, z, args...) x[:] = adjoint(reg.trafo) * result return x end function norm(reg::TransformedRegularization, x::AbstractArray, args...) + shape = size(x) + z = reg.trafo * vec(x) + result = norm(reg.reg, reshape(z, shape), args...) + return result +end +function norm(reg::TransformedRegularization, x::AbstractVector, args...) z = reg.trafo * x result = norm(reg.reg, z, args...) return result diff --git a/src/SplitBregman.jl b/src/SplitBregman.jl index 1d7a8091..e5472a58 100644 --- a/src/SplitBregman.jl +++ b/src/SplitBregman.jl @@ -1,8 +1,9 @@ export SplitBregman -mutable struct SplitBregman{matT,opT,R,ropT,P,vecT,rvecT,preconT,rT} <: AbstractPrimalDualSolver +mutable struct SplitBregman{matT,N,opT,R,ropT,P,vecT,rvecT,preconT,rT} <: AbstractPrimalDualSolver # operators and regularization A::matT + shape::NTuple{N, Int64} reg::Vector{R} regTrafo::Vector{ropT} proj::Vector{P} @@ -81,6 +82,7 @@ function SplitBregman(A , relTol::Real = eps(real(eltype(AHA))) , tolInner::Real = 1e-5 , verbose = false + , shape = (size(AHA, 2),) ) T = eltype(AHA) @@ -136,7 +138,7 @@ function SplitBregman(A # normalization parameters reg = normalize(SplitBregman, normalizeReg, reg, A, nothing) - return SplitBregman(A,reg,regTrafo,proj,y,AHA,β,β_y,x,z,zᵒˡᵈ,u,precon,rho,iterations,iterationsInner,iterationsCG,cgStateVars,rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,rT(0),rT(absTol),rT(relTol),rT(tolInner),iter_cnt,normalizeReg,verbose) + return SplitBregman(A, shape, reg,regTrafo,proj,y,AHA,β,β_y,x,z,zᵒˡᵈ,u,precon,rho,iterations,iterationsInner,iterationsCG,cgStateVars,rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,rT(0),rT(absTol),rT(relTol),rT(tolInner),iter_cnt,normalizeReg,verbose) end """ @@ -193,7 +195,7 @@ function iterate(solver::SplitBregman, iteration=1) cg!(solver.x, AHA, solver.β, Pl = solver.precon, maxiter = solver.iterationsCG, reltol = solver.tolInner, statevars = solver.cgStateVars, verbose = solver.verbose) for proj in solver.proj - prox!(proj, solver.x) + prox!(proj, reshape(solver.x, solver.shape)) end # proximal map for regularization terms @@ -207,7 +209,7 @@ function iterate(solver::SplitBregman, iteration=1) mul!(solver.z[i], solver.regTrafo[i], solver.x) solver.z[i] .+= solver.u[i] if solver.ρ[i] != 0 - prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/2solver.ρ[i]) # λ is divided by 2 to match the ISTA-type algorithms + prox!(solver.reg[i], reshape(solver.z[i], solver.shape), λ(solver.reg[i])/2solver.ρ[i]) # λ is divided by 2 to match the ISTA-type algorithms end # 3. update u diff --git a/src/proximalMaps/ProxLLR.jl b/src/proximalMaps/ProxLLR.jl index f65815ef..702d6612 100644 --- a/src/proximalMaps/ProxLLR.jl +++ b/src/proximalMaps/ProxLLR.jl @@ -15,13 +15,13 @@ Regularization term implementing the proximal map for locally low rank (LLR) reg """ struct LLRRegularization{T, N, TI} <: AbstractParameterizedRegularization{T} where {N, TI<:Integer} λ::T - shape::NTuple{N,TI} + dims::Union{TI} blockSize::NTuple{N,TI} randshift::Bool L::Int64 end -LLRRegularization(λ; shape::NTuple{N,TI}, blockSize::NTuple{N,TI} = ntuple(_ -> 2, N), randshift::Bool = true, L::Int64 = 1, kargs...) where {N,TI<:Integer} = - LLRRegularization(λ, shape, blockSize, randshift, L) +LLRRegularization(λ; dims, blockSize::NTuple{N,TI} = ntuple(_ -> 2, N), randshift::Bool = true, L::Int64 = 1, kargs...) where {N,TI<:Integer} = + LLRRegularization(λ, dims, blockSize, randshift, L) """ prox!(reg::LLRRegularization, x, λ) @@ -29,16 +29,21 @@ LLRRegularization(λ; shape::NTuple{N,TI}, blockSize::NTuple{N,TI} = ntuple(_ - performs the proximal map for LLR regularization using singular-value-thresholding """ function prox!(reg::LLRRegularization{TR, N, TI}, x::AbstractArray{Tc}, λ::T) where {TR, N, TI, T, Tc <: Union{T, Complex{T}}} - shape = reg.shape + dims = reg.dims + otherdims = filter(dim -> dim != dims, 1:ndims(x)) + shape = size(x)[otherdims] blockSize = reg.blockSize randshift = reg.randshift - x = reshape(x, tuple(shape..., length(x) ÷ prod(shape))) - block_idx = CartesianIndices(blockSize) - K = size(x)[end] + K = size(x)[dims] + blocks = zeros(Int64, ndims(x)) + blocks[otherdims] .= blockSize + blocks[dims] = K + block_idx = CartesianIndices(Tuple(blockSize)) if randshift # Random.seed!(1234) + # TODO block_idx was changed shift_idx = (Tuple(rand(block_idx))..., 0) xs = circshift(x, shift_idx) else @@ -48,7 +53,10 @@ function prox!(reg::LLRRegularization{TR, N, TI}, x::AbstractArray{Tc}, λ::T) w ext = mod.(shape, blockSize) pad = mod.(blockSize .- ext, blockSize) if any(pad .!= 0) - xp = zeros(Tc, (shape .+ pad)..., K) + paddedSize = zeros(Int64, ndims(x)) + paddedSize[dims] = K + paddedSize[otherdims] .= shape .+ pad + xp = zeros(Tc, paddedSize...) xp[CartesianIndices(x)] .= xs else xp = xs @@ -59,15 +67,18 @@ function prox!(reg::LLRRegularization{TR, N, TI}, x::AbstractArray{Tc}, λ::T) w BLAS.set_num_threads(1) xᴸᴸᴿ = [Array{Tc}(undef, prod(blockSize), K) for _ = 1:Threads.nthreads()] let xp = xp # Avoid boxing error - @floop for i ∈ CartesianIndices(StepRange.(TI(0), blockSize, shape .- 1)) - @views xᴸᴸᴿ[Threads.threadid()] .= reshape(xp[i.+block_idx, :], :, K) + ranges = fill(StepRange(TI(0), 1, 1), ndims(x)) + ranges[otherdims] .= StepRange.(TI(0), blockSize, shape .- 1) + ranges[dims] = StepRange(1, 1, 1) + @floop for i ∈ CartesianIndices(Tuple(ranges)) + xᴸᴸᴿ[Threads.threadid()] .= reshape(view(xp, i.+block_idx), :, K) # TODO unsure about this reshape if dims != end ub = sqrt(norm(xᴸᴸᴿ[Threads.threadid()]' * xᴸᴸᴿ[Threads.threadid()], Inf)) #upper bound on singular values given by matrix infinity norm if λ >= ub #save time by skipping the SVT as recommended by Ong/Lustig, IEEE 2016 - xp[i.+block_idx, :] .= 0 + xp[i.+block_idx] .= 0 else # threshold singular values SVDec = svd!(xᴸᴸᴿ[Threads.threadid()]) prox!(L1Regularization, SVDec.S, λ) - xp[i.+block_idx, :] .= reshape(SVDec.U * Diagonal(SVDec.S) * SVDec.Vt, blockSize..., :) + xp[i.+block_idx] .= reshape(SVDec.U * Diagonal(SVDec.S) * SVDec.Vt, blockSize..., :) end end end @@ -92,8 +103,8 @@ end returns the value of the LLR-regularization term. """ -function norm(reg::LLRRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} - shape = reg.shape +function norm(reg::LLRRegularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} + shape = size(x) blockSize = reg.blockSize randshift = reg.randshift L = reg.L diff --git a/src/proximalMaps/ProxNuclear.jl b/src/proximalMaps/ProxNuclear.jl index 0d7fc5dd..16d707aa 100644 --- a/src/proximalMaps/ProxNuclear.jl +++ b/src/proximalMaps/ProxNuclear.jl @@ -14,17 +14,16 @@ Regularization term implementing the proximal map for singular value soft-thresh """ struct NuclearRegularization{T} <: AbstractParameterizedRegularization{T} λ::T - svtShape::NTuple + NuclearRegularization(λ::T; kargs...) where T = new{T}(λ) end -NuclearRegularization(λ; svtShape::NTuple=[], kargs...) = NuclearRegularization(λ, svtShape) """ prox!(reg::NuclearRegularization, x, λ) performs singular value soft-thresholding - i.e. the proximal map for the nuclear norm regularization. """ -function prox!(reg::NuclearRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} - U,S,V = svd(reshape(x, reg.svtShape)) +function prox!(reg::NuclearRegularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} + U,S,V = svd(x) prox!(L1Regularization, S, λ) x[:] = vec(U*Matrix(Diagonal(S))*V') return x @@ -35,7 +34,7 @@ end returns the value of the nuclear norm regularization term. """ -function norm(reg::NuclearRegularization, x::Vector{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} - U,S,V = svd( reshape(x, reg.svtShape) ) +function norm(reg::NuclearRegularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} + U,S,V = svd(x) return λ*norm(S,1) end diff --git a/src/proximalMaps/ProxProj.jl b/src/proximalMaps/ProxProj.jl index 4a19abca..d7c0ed68 100644 --- a/src/proximalMaps/ProxProj.jl +++ b/src/proximalMaps/ProxProj.jl @@ -5,12 +5,12 @@ struct ProjectionRegularization <: AbstractProjectionRegularization end ProjectionRegularization(; projFunc::Function=x->x, kargs...) = ProjectionRegularization(projFunc) -function prox!(reg::ProjectionRegularization, x::Vector{Tc}) where {T, Tc <: Union{T, Complex{T}}} +function prox!(reg::ProjectionRegularization, x::AbstractArray{Tc}) where {T, Tc <: Union{T, Complex{T}}} x[:] = reg.projFunc(x) return x end -function norm(reg::ProjectionRegularization, x::Vector{Tc}) where {T, Tc <: Union{T, Complex{T}}} +function norm(reg::ProjectionRegularization, x::AbstractArray{Tc}) where {T, Tc <: Union{T, Complex{T}}} y = copy(x) y[:] = prox!(reg, y) if y != x diff --git a/src/proximalMaps/ProxReal.jl b/src/proximalMaps/ProxReal.jl index af898ab5..c083ae32 100644 --- a/src/proximalMaps/ProxReal.jl +++ b/src/proximalMaps/ProxReal.jl @@ -13,7 +13,7 @@ end enforce realness of solution `x`. """ -function prox!(::RealRegularization, x::Vector{T}) where T +function prox!(::RealRegularization, x::AbstractArray{T}) where T enfReal!(x) return x end @@ -23,7 +23,7 @@ end returns the value of the characteristic function of real, Real numbers. """ -function norm(reg::RealRegularization, x::Vector{T}) where T +function norm(reg::RealRegularization, x::AbstractArray{T}) where T y = copy(x) prox!(reg, y) if y != x diff --git a/src/proximalMaps/ProxTV.jl b/src/proximalMaps/ProxTV.jl index 48cf39f1..57411c46 100644 --- a/src/proximalMaps/ProxTV.jl +++ b/src/proximalMaps/ProxTV.jl @@ -21,88 +21,63 @@ and Deblurring Problems", IEEE Trans. Image Process. 18(11), 2009 * `dims` - Dimension to perform the TV along. If `Integer`, the Condat algorithm is called, and the FDG algorithm otherwise. * `iterationsTV=20` - number of FGP iterations """ -struct TVRegularization{T,N,TI} <: AbstractParameterizedRegularization{T} where {N,TI<:Integer} +struct TVRegularization{T, D} <: AbstractParameterizedRegularization{T} where {N, D <: Union{Nothing, Int64, NTuple{N, Int64}}} λ::T - dims - shape::NTuple{N,TI} - iterationsTV::Int64 + dims::D + iterations::Int64 end -TVRegularization(λ; shape=(0,), dims=1:length(shape), iterationsTV=10, kargs...) = TVRegularization(λ, dims, shape, iterationsTV) - - -mutable struct TVParams{Tc,matT} - pq::Vector{Tc} - rs::Vector{Tc} - pqOld::Vector{Tc} - xTmp::Vector{Tc} - ∇::matT -end - -function TVParams(shape, T::Type=Float64; dims=1:length(shape)) - return TVParams(Vector{T}(undef, prod(shape)); shape=shape, dims=dims) -end - -function TVParams(x::AbstractVector{Tc}; shape, dims=1:length(shape)) where {Tc} - ∇ = GradientOp(Tc; shape, dims) - - # allocate storage - xTmp = similar(x) - pq = similar(x, size(∇, 1)) - rs = similar(pq) - pqOld = similar(pq) - - return TVParams(pq, rs, pqOld, xTmp, ∇) -end - - +TVRegularization(λ; dims=nothing, iterations=10, kargs...) = TVRegularization(λ, dims, iterations) """ prox!(reg::TVRegularization, x, λ) Proximal map for TV regularization. Calculated with the Condat algorithm if the TV is calculated only along one dimension and with the Fast Gradient Projection algorithm otherwise. """ -prox!(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T,Tc<:Union{T,Complex{T}}} = proxTV!(x, λ, shape=reg.shape, dims=reg.dims, iterationsTV=reg.iterationsTV) +prox!(reg::TVRegularization, x::AbstractArray{Tc}, λ::T) where {T,Tc<:Union{T,Complex{T}}} = proxTV!(x, λ, reg.dims, iterations=reg.iterations) +prox!(reg::TVRegularization{T, Nothing}, x::AbstractArray{Tc}, λ::T) where {T,Tc<:Union{T,Complex{T}}} = proxTV!(x, λ, 1:ndims(x), iterations=reg.iterations) -function proxTV!(x, λ; shape, dims=1:length(shape), kwargs...) # use kwargs for shape and dims - return proxTV!(x, λ, shape, dims; kwargs...) # define shape and dims w/o kwargs to enable multiple dispatch on dims -end - -function proxTV!(x::AbstractVector{T}, λ::T, shape, dims::Integer; kwargs...) where {T<:Real} - x_ = reshape(x, shape) +function proxTV!(x::AbstractArray{T}, λ::T, dims::Integer; kwargs...) where {T<:Real} + shape = size(x) i = CartesianIndices((ones(Int, dims - 1)..., 0:shape[dims]-1, ones(Int, length(shape) - dims)...)) Threads.@threads for j ∈ CartesianIndices((shape[1:dims-1]..., 1, shape[dims+1:end]...)) - @views @inbounds tv_denoise_1d_condat!(x_[j.+i], shape[dims], λ) + @views @inbounds tv_denoise_1d_condat!(x[j.+i], shape[dims], λ) end return x end -function proxTV!(x::AbstractVector{Tc}, λ::T, shape, dims; iterationsTV=10, tvpar=TVParams(x; shape=shape, dims=dims), kwargs...) where {T<:Real,Tc<:Union{T,Complex{T}}} - return proxTV!(x, λ, tvpar; iterationsTV=iterationsTV) -end +function proxTV!(x::AbstractArray{Tc}, λ::T, dims = 1:ndims(x); iterations=10) where {T<:Real,Tc<:Union{T,Complex{T}}} + shape = size(x) + ∇ = GradientOp(Tc; shape, dims) + + # allocate these in reg term and reuse them? + # allocate storage + xTmp = similar(vec(x)) + pq = similar(xTmp, size(∇, 1)) + rs = similar(pq) + pqOld = similar(pq) -function proxTV!(x::AbstractVector{Tc}, λ::T, p::TVParams{Tc}; iterationsTV=10, kwargs...) where {T<:Real,Tc<:Union{T,Complex{T}}} - @assert length(p.xTmp) == length(x) + @assert length(xTmp) == length(x) # initialize dual variables - p.xTmp .= 0 - p.pq .= 0 - p.rs .= 0 - p.pqOld .= 0 + xTmp .= 0 + pq .= 0 + rs .= 0 + pqOld .= 0 t = one(T) - for _ = 1:iterationsTV - pqTmp = p.pqOld - p.pqOld = p.pq - p.pq = p.rs + for _ = 1:iterations + pqTmp = pqOld + pqOld = pq + pq = rs # gradient projection step for dual variables - Threads.@threads for i ∈ eachindex(p.xTmp, x) - @inbounds p.xTmp[i] = x[i] + Threads.@threads for i ∈ eachindex(xTmp, x) + @inbounds xTmp[i] = x[i] end - mul!(p.xTmp, transpose(p.∇), p.rs, -λ, 1) # xtmp = x-λ*transpose(∇)*rs - mul!(p.pq, p.∇, p.xTmp, 1 / (8λ), 1) # rs = ∇*xTmp/(8λ) + mul!(xTmp, transpose(∇), rs, -λ, 1) # xtmp = x-λ*transpose(∇)*rs + mul!(pq, ∇, xTmp, 1 / (8λ), 1) # rs = ∇*xTmp/(8λ) - restrictMagnitude!(p.pq) + restrictMagnitude!(pq) # form linear combination of old and new estimates tOld = t @@ -110,13 +85,13 @@ function proxTV!(x::AbstractVector{Tc}, λ::T, p::TVParams{Tc}; iterationsTV=10, t2 = ((tOld - 1) / t) t3 = 1 + t2 - p.rs = pqTmp - Threads.@threads for i ∈ eachindex(p.rs, p.pq, p.pqOld) - @inbounds p.rs[i] = t3 * p.pq[i] - t2 * p.pqOld[i] + rs = pqTmp + Threads.@threads for i ∈ eachindex(rs, pq, pqOld) + @inbounds rs[i] = t3 * pq[i] - t2 * pqOld[i] end end - mul!(x, transpose(p.∇), p.pq, -λ, one(Tc)) # x .-= λ*transpose(∇)*pq + mul!(vec(x), transpose(∇), pq, -λ, one(Tc)) # x .-= λ*transpose(∇)*pq return x end @@ -132,7 +107,7 @@ end returns the value of the TV-regularization term. """ -function norm(reg::TVRegularization, x::Vector{Tc}, λ::T) where {T<:Real,Tc<:Union{T,Complex{T}}} - ∇ = GradientOp(Tc; shape=reg.shape, dims=reg.dims) - return λ * norm(∇ * x, 1) +function norm(reg::TVRegularization, x::AbstractArray{Tc}, λ::T) where {T<:Real,Tc<:Union{T,Complex{T}}} + ∇ = GradientOp(Tc; shape=size(x), dims= isnothing(reg.dims) ? UnitRange(1, ndims(x)) : reg.dims) + return λ * norm(∇ * vec(x), 1) end