Add shape keyword to solvers #72

Draft: wants to merge 3 commits into master

Changes from all commits
10 changes: 6 additions & 4 deletions src/ADMM.jl
@@ -1,8 +1,9 @@
export ADMM

-mutable struct ADMM{matT,opT,R,ropT,P,vecT,rvecT,preconT,rT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}}
+mutable struct ADMM{matT,N,opT,R,ropT,P,vecT,rvecT,preconT,rT} <: AbstractPrimalDualSolver where {vecT <: AbstractVector{Union{rT, Complex{rT}}}, rvecT <: AbstractVector{rT}}
# operators and regularization
A::matT
+shape::NTuple{N, Int64}
reg::Vector{R}
regTrafo::Vector{ropT}
proj::Vector{P}
@@ -81,6 +82,7 @@ function ADMM(A
, relTol::Real = eps(real(eltype(AHA)))
, tolInner::Real = 1e-5
, verbose = false
+, shape = (size(AHA, 2),)
)

T = eltype(AHA)
@@ -135,7 +137,7 @@ function ADMM(A
# normalization parameters
reg = normalize(ADMM, normalizeReg, reg, A, nothing)

-return ADMM(A,reg,regTrafo,proj,AHA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,rho,iterations
+return ADMM(A,shape,reg,regTrafo,proj,AHA,β,β_y,x,xᵒˡᵈ,z,zᵒˡᵈ,u,uᵒˡᵈ,precon,rho,iterations
,iterationsCG,cgStateVars,rᵏ,sᵏ,ɛᵖʳⁱ,ɛᵈᵘᵃ,rT(0),Δ,rT(absTol),rT(relTol),rT(tolInner),normalizeReg,vary_rho,verbose)
end

@@ -198,7 +200,7 @@ function iterate(solver::ADMM, iteration=1)
cg!(solver.x, AHA, solver.β, Pl = solver.precon, maxiter = solver.iterationsCG, reltol = solver.tolInner, statevars = solver.cgStateVars, verbose = solver.verbose)

for proj in solver.proj
-prox!(proj, solver.x)
+prox!(proj, reshape(solver.x, solver.shape))
end

# proximal map for regularization terms
@@ -212,7 +214,7 @@
mul!(solver.z[i], solver.regTrafo[i], solver.x)
solver.z[i] .+= solver.u[i]
if solver.ρ[i] != 0
-prox!(solver.reg[i], solver.z[i], λ(solver.reg[i])/2solver.ρ[i]) # λ is divided by 2 to match the ISTA-type algorithms
+prox!(solver.reg[i], reshape(solver.z[i], solver.shape), λ(solver.reg[i])/2solver.ρ[i]) # λ is divided by 2 to match the ISTA-type algorithms
end

# 3. update u
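The practical effect of the new keyword is that shape-aware regularization terms now receive the iterate as an array of the requested shape rather than as a flat vector. Below is a minimal usage sketch, assuming the package's `solve!` entry point and the remaining constructor defaults; the operator size, the circular mask, and the regularization weight are made up for illustration:

```julia
using RegularizedLeastSquares

img_shape = (64, 64)
A = rand(ComplexF32, 1000, prod(img_shape))   # A acts on the vectorized image
b = A * vec(rand(ComplexF32, img_shape))

# Regularize only inside a circular field of view. The 2D mask works because
# prox! now receives the iterate reshaped to img_shape (see MaskedRegularization further down).
mask = [(i - 32)^2 + (j - 32)^2 <= 30^2 for i in 1:64, j in 1:64]
reg  = MaskedRegularization(L1Regularization(0.01f0), mask)

solver = ADMM(A; reg = reg, iterations = 20, shape = img_shape)
x = solve!(solver, b)                         # prox! operates on reshape(x, img_shape)
```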
8 changes: 5 additions & 3 deletions src/CGNR.jl
@@ -1,8 +1,9 @@
export cgnr, CGNR

-mutable struct CGNR{matT,opT,vecT,T,R,PR} <: AbstractKrylovSolver
+mutable struct CGNR{matT,opT, N,vecT,T,R,PR} <: AbstractKrylovSolver
A::matT
AHA::opT
+shape::NTuple{N, Int64}
L2::R
constr::PR
x::vecT
@@ -49,6 +50,7 @@ function CGNR(A
, weights::AbstractVector = similar(AHA, 0)
, iterations::Int = 10
, relTol::Real = eps(real(eltype(AHA)))
+, shape = (size(AHA, 2),)
)

T = eltype(AHA)
@@ -82,7 +84,7 @@ function CGNR(A
other = identity.(other)


-return CGNR(A, AHA,
+return CGNR(A, AHA, shape,
L2, other, x, x₀, pl, vl, αl, βl, ζl, weights, iterations, relTol, 0.0, normalizeReg)
end

@@ -134,7 +136,7 @@ performs one CGNR iteration.
function iterate(solver::CGNR, iteration::Int=0)
if done(solver, iteration)
for r in solver.constr
-prox!(r, solver.x)
+prox!(r, reshape(solver.x, solver.shape))
end
return nothing
end
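The default `shape = (size(AHA, 2),)` is simply the length of the solution vector, so existing code that never passes the keyword keeps the previous flat-vector behaviour. A quick sanity sketch (hypothetical sizes; the remaining constructor defaults are assumed):

```julia
using RegularizedLeastSquares

A = rand(Float32, 40, 24)
solver = CGNR(A; iterations = 5)   # no shape given
solver.shape == (24,)              # true: prox! still sees a plain length-24 vector
```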
6 changes: 4 additions & 2 deletions src/DAXConstrained.jl
@@ -1,7 +1,8 @@
export DaxConstrained

-mutable struct DaxConstrained{matT,T,Tsparse,U} <: AbstractRowActionSolver
+mutable struct DaxConstrained{matT,N,T,Tsparse,U} <: AbstractRowActionSolver
A::matT
+shape::NTuple{N, Int64}
u::Vector{T}
λ::Float64
B::Tsparse
@@ -49,6 +50,7 @@ function DaxConstrained(A
, sparseTrafo=nothing
, iterations::Int=3
, iterationsInner::Int=2
+, shape = (size(A, 2),)
)

T = eltype(A)
@@ -79,7 +81,7 @@ function DaxConstrained(A
τl = zero(T)
αl = zero(T)

-return DaxConstrained(A,u,Float64(λ),B,Bnorm²,denom,rowindex,x,bk,bc,xl,yl,yc,δc,εw,τl,αl
+return DaxConstrained(A,shape,u,Float64(λ),B,Bnorm²,denom,rowindex,x,bk,bc,xl,yl,yc,δc,εw,τl,αl
,rT.(weights),iterations,iterationsInner)
end

8 changes: 5 additions & 3 deletions src/DAXKaczmarz.jl
@@ -1,7 +1,8 @@
export DaxKaczmarz

-mutable struct DaxKaczmarz{matT,T,U} <: AbstractRowActionSolver
+mutable struct DaxKaczmarz{matT,N,T,U} <: AbstractRowActionSolver
A::matT
+shape::NTuple{N, Int64}
u::Vector{T}
reg::Vector{<:AbstractRegularization}
λ::Float64
@@ -51,6 +52,7 @@ function DaxKaczmarz(A
, enforcePositive::Bool=false
, iterations::Int=3
, iterationsInner::Int=2
+, shape = (size(A, 2),)
)

# setup denom and rowindex
@@ -80,7 +82,7 @@ function DaxKaczmarz(A
if !isempty(reg) && !isnothing(sparseTrafo)
reg = map(r -> TransformedRegularization(r, sparseTrafo), reg)
end
-return DaxKaczmarz(A,u,reg, Float64(λ), denom,rowindex,sumrowweights,x,bk,xl,yl,εw,τl,αl
+return DaxKaczmarz(A,shape,u,reg, Float64(λ), denom,rowindex,sumrowweights,x,bk,xl,yl,εw,τl,αl
,T.(weights) ,iterations,iterationsInner)
end

@@ -103,7 +105,7 @@ end
function iterate(solver::DaxKaczmarz, iteration::Int=0)
if done(solver,iteration)
for r in solver.reg
-prox!(r, solver.x)
+prox!(r, reshape(solver.x, solver.shape))
end
return nothing
end
10 changes: 6 additions & 4 deletions src/FISTA.jl
@@ -1,8 +1,9 @@
export FISTA

-mutable struct FISTA{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVector{Complex{rT}}}, matA, matAHA, R, RN} <: AbstractProximalGradientSolver
+mutable struct FISTA{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVector{Complex{rT}}}, matA, N, matAHA, R, RN} <: AbstractProximalGradientSolver
A::matA
AHA::matAHA
+shape::NTuple{N, Int64}
reg::R
proj::Vector{RN}
x::vecT
@@ -60,6 +61,7 @@ function FISTA(A
, iterations = 50
, restart = :none
, verbose = false
+, shape = (size(AHA, 2),)
)

T = eltype(AHA)
@@ -87,7 +89,7 @@ function FISTA(A
reg = normalize(FISTA, normalizeReg, reg, A, nothing)


-return FISTA(A, AHA, reg[1], other, x, x₀, xᵒˡᵈ, res, rT(rho),rT(theta),rT(theta),iterations,rT(relTol),normalizeReg,one(rT),rT(Inf),verbose,restart)
+return FISTA(A, AHA, shape, reg[1], other, x, x₀, xᵒˡᵈ, res, rT(rho),rT(theta),rT(theta),iterations,rT(relTol),normalizeReg,one(rT),rT(Inf),verbose,restart)
end

"""
@@ -146,10 +148,10 @@ function iterate(solver::FISTA, iteration::Int=0)
# solver.x .+= solver.ρ .* solver.x₀

# proximal map
-prox!(solver.reg, solver.x, solver.ρ * λ(solver.reg))
+prox!(solver.reg, reshape(solver.x, solver.shape), solver.ρ * λ(solver.reg))

for proj in solver.proj
-prox!(proj, solver.x)
+prox!(proj, reshape(solver.x, solver.shape))
end

# gradient restart conditions
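A second use case is plug-and-play priors: with the solver reshaping the iterate itself, `PlugAndPlayRegularization` (further down in this diff) no longer needs its own `shape`. A sketch with a placeholder denoiser; the identity closure, all sizes, and the `solve!` call are assumptions rather than part of this diff:

```julia
using RegularizedLeastSquares

denoiser  = x -> x                          # stand-in for a trained model
img_shape = (28, 28)
A = rand(Float32, 400, prod(img_shape))
b = A * vec(rand(Float32, img_shape))

reg = PlugAndPlayRegularization(denoiser)   # no shape argument anymore
solver = FISTA(A; reg = reg, iterations = 30, shape = img_shape)
x = solve!(solver, b)                       # the model sees 28×28 arrays
```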
8 changes: 5 additions & 3 deletions src/Kaczmarz.jl
@@ -1,8 +1,9 @@
export kaczmarz
export Kaczmarz

-mutable struct Kaczmarz{matT,T,U,R,RN} <: AbstractRowActionSolver
+mutable struct Kaczmarz{matT,N,T,U,R,RN} <: AbstractRowActionSolver
A::matT
+shape::NTuple{N, Int64}
u::Vector{T}
L2::R
reg::Vector{RN}
@@ -55,6 +56,7 @@ function Kaczmarz(A
, seed::Int = 1234
, iterations::Int = 10
, regMatrix = nothing
+, shape = (size(A, 2),)
)

T = real(eltype(A))
@@ -105,7 +107,7 @@ function Kaczmarz(A
τl = zero(eltype(A))
αl = zero(eltype(A))

-return Kaczmarz(A, u, L2, other, denom, rowindex, rowIndexCycle, x, vl, εw, τl, αl,
+return Kaczmarz(A, shape, u, L2, other, denom, rowindex, rowIndexCycle, x, vl, εw, τl, αl,
T.(w), randomized, subMatrixSize, probabilities, shuffleRows,
Int64(seed), iterations, regMatrix,
normalizeReg)
@@ -167,7 +169,7 @@ function iterate(solver::Kaczmarz, iteration::Int=0)
end

for r in solver.reg
-prox!(r, solver.x)
+prox!(r, reshape(solver.x, solver.shape))
end

return solver.vl, iteration+1
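For the row-action solvers (`Kaczmarz`, `DaxKaczmarz`, `DaxConstrained`) the iterate has `size(A, 2)` entries, so a user-supplied `shape` has to satisfy `prod(shape) == size(A, 2)` for `reshape(solver.x, solver.shape)` to be valid. A small sketch with made-up sizes:

```julia
using RegularizedLeastSquares

A = rand(Float64, 300, 16 * 16)    # 300 measurements of a 16×16 image
solver = Kaczmarz(A; iterations = 10, shape = (16, 16))
prod(solver.shape) == size(A, 2)   # must hold, otherwise the reshape in iterate throws
```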
8 changes: 5 additions & 3 deletions src/OptISTA.jl
@@ -1,8 +1,9 @@
export optista, OptISTA

-mutable struct OptISTA{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVector{Complex{rT}}}, matA, matAHA, R, RN} <: AbstractProximalGradientSolver
+mutable struct OptISTA{rT <: Real, vecT <: Union{AbstractVector{rT}, AbstractVector{Complex{rT}}}, N, matA, matAHA, R, RN} <: AbstractProximalGradientSolver
A::matA
AHA::matAHA
+shape::NTuple{N, Int64}
reg::R
proj::Vector{RN}
x::vecT
@@ -65,6 +66,7 @@ function OptISTA(A
, relTol = eps(real(eltype(AHA)))
, iterations = 50
, verbose = false
+, shape = (size(AHA, 2),)
)

T = eltype(AHA)
@@ -98,7 +100,7 @@ function OptISTA(A
other = identity.(other)
reg = normalize(OptISTA, normalizeReg, reg, A, nothing)

-return OptISTA(A, AHA, reg[1], other, x, x₀, y, z, zᵒˡᵈ, res, rT(rho),rT(theta),rT(theta),rT(θn),rT(0),rT(1),rT(1),
+return OptISTA(A, AHA, shape, reg[1], other, x, x₀, y, z, zᵒˡᵈ, res, rT(rho),rT(theta),rT(theta),rT(θn),rT(0),rT(1),rT(1),
iterations,rT(relTol),normalizeReg,one(rT),rT(Inf),verbose)
end

@@ -169,7 +171,7 @@ function iterate(solver::OptISTA, iteration::Int=0)
solver.verbose && println("Iteration $iteration; rel. residual = $(solver.rel_res_norm)")

# proximal map
-prox!(solver.reg, solver.y, solver.ρ * solver.γ * λ(solver.reg))
+prox!(solver.reg, reshape(solver.y, solver.shape), solver.ρ * solver.γ * λ(solver.reg))

# inertia steps
# z = x + (y - yᵒˡᵈ) / γ
10 changes: 6 additions & 4 deletions src/POGM.jl
@@ -1,8 +1,9 @@
export pogm, POGM

-mutable struct POGM{rT<:Real,vecT<:Union{AbstractVector{rT},AbstractVector{Complex{rT}}},matA,matAHA,R,RN} <: AbstractProximalGradientSolver
+mutable struct POGM{rT<:Real,vecT<:Union{AbstractVector{rT},AbstractVector{Complex{rT}}},matA,matAHA,N,R,RN} <: AbstractProximalGradientSolver
A::matA
AHA::matAHA
+shape::NTuple{N, Int64}
reg::R
proj::Vector{RN}
x::vecT
@@ -81,6 +82,7 @@ function POGM(A
, iterations = 50
, restart = :none
, verbose = false
+, shape = (size(AHA, 2),)
)

T = eltype(AHA)
@@ -109,7 +111,7 @@ function POGM(A
other = identity.(other)
reg = normalize(POGM, normalizeReg, reg, A, nothing)

-return POGM(A, AHA, reg[1], other, x, x₀, xᵒˡᵈ, y, z, w, res, rT(rho), rT(theta), rT(theta), rT(0), rT(1), rT(1), rT(1), rT(1), rT(sigma_fac),
+return POGM(A, AHA, shape, reg[1], other, x, x₀, xᵒˡᵈ, y, z, w, res, rT(rho), rT(theta), rT(theta), rT(0), rT(1), rT(1), rT(1), rT(1), rT(sigma_fac),
iterations, rT(relTol), normalizeReg, one(rT), rT(Inf), verbose, restart)
end

@@ -192,9 +194,9 @@ function iterate(solver::POGM, iteration::Int=0)
solver.z .= solver.x #store this for next iteration and GR

# proximal map
-prox!(solver.reg, solver.x, solver.γ * λ(solver.reg))
+prox!(solver.reg, reshape(solver.x, solver.shape), solver.γ * λ(solver.reg))
for proj in solver.proj
-prox!(proj, solver.x)
+prox!(proj, reshape(solver.x, solver.shape))
end

# gradient restart conditions
6 changes: 3 additions & 3 deletions src/Regularization/MaskedRegularization.jl
@@ -21,20 +21,20 @@ julia> prox!(masked, fill(-1, 4))
"""
struct MaskedRegularization{S, R<:AbstractRegularization} <: AbstractNestedRegularization{S}
reg::R
-mask::Vector{Bool}
+mask::AbstractArray{Bool}
MaskedRegularization(reg::R, mask) where R <: AbstractRegularization = new{R, R}(reg, mask)
MaskedRegularization(reg::R, mask) where {S, R<:AbstractNestedRegularization{S}} = new{S,R}(reg, mask)
end
innerreg(reg::MaskedRegularization) = reg.reg


function prox!(reg::MaskedRegularization, x::AbstractArray, args...)
-z = view(x, findall(reg.mask))
+z = view(x, reg.mask)
prox!(reg.reg, z, args...)
return x
end
function norm(reg::MaskedRegularization, x::AbstractArray, args...)
-z = view(x, findall(reg.mask))
+z = view(x, reg.mask)
result = norm(reg.reg, z, args...)
return result
end
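With the mask now an `AbstractArray{Bool}`, it can carry the same shape as the (reshaped) iterate, and `view(x, reg.mask)` selects the same elements that `view(x, findall(reg.mask))` did before. A small standalone example with made-up 3×3 data:

```julia
using RegularizedLeastSquares, LinearAlgebra

mask = falses(3, 3); mask[2, 2] = true              # regularize only the center entry
masked = MaskedRegularization(L1Regularization(1.0), mask)

x = fill(0.5, 3, 3)
norm(masked, x)     # evaluates the L1 term on the masked entry only
prox!(masked, x)    # soft-thresholds x[2, 2], leaves all other entries untouched
```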
4 changes: 2 additions & 2 deletions src/Regularization/NestedRegularization.jl
@@ -26,5 +26,5 @@ sinktype(::AbstractNestedRegularization{S}) where S = S
prox!(reg::AbstractNestedRegularization{S}, x) where S <: AbstractParameterizedRegularization = prox!(reg, x, λ(reg))
norm(reg::AbstractNestedRegularization{S}, x) where S <: AbstractParameterizedRegularization = norm(reg, x, λ(reg))

-prox!(reg::AbstractNestedRegularization, x, args...) = prox!(innerreg(reg), x, args...)
-norm(reg::AbstractNestedRegularization, x, args...) = norm(innerreg(reg), x, args...)
+#prox!(reg::AbstractNestedRegularization, x, args...) = prox!(innerreg(reg), x, args...)
+#norm(reg::AbstractNestedRegularization, x, args...) = norm(innerreg(reg), x, args...)
8 changes: 2 additions & 6 deletions src/Regularization/PlugAndPlayRegularization.jl
@@ -11,18 +11,16 @@ The actual regularization term is indirectly defined by the learned proximal map

# Keywords
* `model` - model applied to the image
-* `shape` - dimensions of the image
* `input_transform` - transform of image before `model`
"""
struct PlugAndPlayRegularization{T, M, I} <: AbstractParameterizedRegularization{T}
model::M
λ::T
-shape::Vector{Int}
input_transform::I
ignoreIm::Bool
-PlugAndPlayRegularization(λ::T; model::M, shape, input_transform::I=RegularizedLeastSquares.MinMaxTransform, ignoreIm = false, kargs...) where {T, M, I} = new{T, M, I}(model, λ, shape, input_transform, ignoreIm)
+PlugAndPlayRegularization(λ::T; model::M, input_transform::I=RegularizedLeastSquares.MinMaxTransform, ignoreIm = false, kargs...) where {T<:Number, M, I} = new{T, M, I}(model, λ, input_transform, ignoreIm)
end
-PlugAndPlayRegularization(model, shape; kwargs...) = PlugAndPlayRegularization(one(Float32); kwargs..., model = model, shape = shape)
+PlugAndPlayRegularization(model; kwargs...) = PlugAndPlayRegularization(one(Float32); kwargs..., model = model)

function prox!(self::PlugAndPlayRegularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Complex{T}}
out = real.(x)
@@ -43,8 +41,6 @@ function prox!(self::PlugAndPlayRegularization, x::AbstractArray{T}, λ::T) where
end

out = copy(x)
-out = reshape(out, self.shape...)
-

tf = self.input_transform(out)

out = RegularizedLeastSquares.transform(tf, out)
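Without the `shape` field, the wrapped model simply sees whatever array shape the caller (typically a solver constructed with `shape = ...`) hands to `prox!`. A sketch with a stand-in model; the identity closure and the 32×32 size are made up:

```julia
using RegularizedLeastSquares

denoiser = x -> x                           # placeholder for a learned denoiser
reg = PlugAndPlayRegularization(denoiser)   # was: PlugAndPlayRegularization(denoiser, (32, 32))

x = rand(Float32, 32, 32)                   # already shaped by the caller
prox!(reg, x, 0.5f0)
```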
13 changes: 13 additions & 0 deletions src/Regularization/TransformedRegularization.jl
@@ -26,12 +26,25 @@ end
innerreg(reg::TransformedRegularization) = reg.reg

+function prox!(reg::TransformedRegularization, x::AbstractArray, args...)
+shape = size(x)
+z = reg.trafo * vec(x)
+result = prox!(reg.reg, reshape(z, shape), args...)
+x[:] = adjoint(reg.trafo) * vec(result)
+return x
+end
function prox!(reg::TransformedRegularization, x::AbstractVector, args...)
z = reg.trafo * x
result = prox!(reg.reg, z, args...)
x[:] = adjoint(reg.trafo) * result
return x
end
+function norm(reg::TransformedRegularization, x::AbstractArray, args...)
+shape = size(x)
+z = reg.trafo * vec(x)
+result = norm(reg.reg, reshape(z, shape), args...)
+return result
+end
function norm(reg::TransformedRegularization, x::AbstractVector, args...)
z = reg.trafo * x
result = norm(reg.reg, z, args...)
return result
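The new `AbstractArray` methods vectorize the input, apply the transform, run the inner term on the coefficients reshaped to the input's size, and map the result back into `x`. A sketch with an arbitrary orthonormal transform; the 4×4 size, the matrix `Q`, and the value of λ are made up:

```julia
using RegularizedLeastSquares, LinearAlgebra

Q = Matrix(qr(randn(16, 16)).Q)     # arbitrary orthonormal "sparsifying" transform
reg = TransformedRegularization(L1Regularization(0.1), Q)

x = randn(4, 4)                     # prox! and norm now also accept shaped arrays
prox!(reg, x, 0.1)                  # soft-thresholds Q * vec(x), maps the result back into x
norm(reg, x, 0.1)                   # L1 norm of the transformed coefficients
```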