Skip to content

Differentiate between L2 and squared L2 proximal maps #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/API/regularization.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ REPL one can access this documentation by entering the help mode with `?`

```@docs
RegularizedLeastSquares.L1Regularization
RegularizedLeastSquares.L2Regularization
RegularizedLeastSquares.SqrL2Regularization
RegularizedLeastSquares.L21Regularization
RegularizedLeastSquares.LLRRegularization
RegularizedLeastSquares.NuclearRegularization
Expand Down
4 changes: 2 additions & 2 deletions docs/src/regularization.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ This group of regularization terms features a regularization parameter `λ` that
These terms are constructed by supplying a `λ` and optionally term specific keyword arguments:

```jldoctest l2
julia> l2 = L2Regularization(0.3)
L2Regularization{Float64}(0.3)
julia> l2 = SqrL2Regularization(0.3)
SqrL2Regularization{Float64}(0.3)
```
Parameterized regularization terms implement:
```julia
Expand Down
12 changes: 6 additions & 6 deletions src/Direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ mutable struct DirectSolver{matT,vecT, R, PR} <: AbstractDirectSolver
proj::Vector{PR}
end

function DirectSolver(A; reg::Vector{<:AbstractRegularization} = [L2Regularization(zero(real(eltype(A))))], normalizeReg::AbstractRegularizationNormalization = NoNormalization())
function DirectSolver(A; reg::Vector{<:AbstractRegularization} = [SqrL2Regularization(zero(real(eltype(A))))], normalizeReg::AbstractRegularizationNormalization = NoNormalization())
reg = normalize(DirectSolver, normalizeReg, reg, A, nothing)
idx = findsink(L2Regularization, reg)
idx = findsink(SqrL2Regularization, reg)
if isnothing(idx)
L2 = L2Regularization(zero(T))
L2 = SqrL2Regularization(zero(T))
else
L2 = reg[idx]
deleteat!(reg, idx)
Expand Down Expand Up @@ -98,11 +98,11 @@ mutable struct PseudoInverse{R, vecT, PR} <: AbstractDirectSolver
proj::Vector{PR}
end

function PseudoInverse(A; reg::Vector{<:AbstractRegularization} = [L2Regularization(zero(real(eltype(A))))], normalizeReg::AbstractRegularizationNormalization = NoNormalization())
function PseudoInverse(A; reg::Vector{<:AbstractRegularization} = [SqrL2Regularization(zero(real(eltype(A))))], normalizeReg::AbstractRegularizationNormalization = NoNormalization())
reg = normalize(PseudoInverse, normalizeReg, reg, A, nothing)
idx = findsink(L2Regularization, reg)
idx = findsink(SqrL2Regularization, reg)
if isnothing(idx)
L2 = L2Regularization(zero(T))
L2 = SqrL2Regularization(zero(T))
else
L2 = reg[idx]
deleteat!(reg, idx)
Expand Down
8 changes: 4 additions & 4 deletions src/Kaczmarz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ mutable struct Kaczmarz{matT,T,U,R,RN} <: AbstractRowActionSolver
end

"""
Kaczmarz(A; reg = L2Regularization(0), normalizeReg = NoNormalization(), weights=nothing, randomized=false, subMatrixFraction=0.15, shuffleRows=false, seed=1234, iterations=10, regMatrix=nothing)
Kaczmarz(A; reg = SqrL2Regularization(0), normalizeReg = NoNormalization(), weights=nothing, randomized=false, subMatrixFraction=0.15, shuffleRows=false, seed=1234, iterations=10, regMatrix=nothing)

Creates a Kaczmarz object for the forward operator `A`.

Expand All @@ -46,7 +46,7 @@ Creates a Kaczmarz object for the forward operator `A`.
See also [`createLinearSolver`](@ref), [`solve!`](@ref).
"""
function Kaczmarz(A
; reg = L2Regularization(0)
; reg = SqrL2Regularization(0)
, normalizeReg::AbstractRegularizationNormalization = NoNormalization()
, weights = nothing
, randomized::Bool = false
Expand All @@ -68,9 +68,9 @@ function Kaczmarz(A
# Prepare regularization terms
reg = isa(reg, AbstractVector) ? reg : [reg]
reg = normalize(Kaczmarz, normalizeReg, reg, A, nothing)
idx = findsink(L2Regularization, reg)
idx = findsink(SqrL2Regularization, reg)
if isnothing(idx)
L2 = L2Regularization(zero(T))
L2 = SqrL2Regularization(zero(T))
else
L2 = reg[idx]
deleteat!(reg, idx)
Expand Down
2 changes: 1 addition & 1 deletion src/RegularizedLeastSquares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ isapplicable(::Type{T}, reg::Vector{<:AbstractRegularization}) where T <: Abstra
function isapplicable(::Type{T}, reg::Vector{<:AbstractRegularization}) where T <: AbstractRowActionSolver
applicable = true
applicable &= length(findsinks(AbstractParameterizedRegularization, reg)) <= 2
applicable &= length(findsinks(L2Regularization, reg)) == 1
applicable &= length(findsinks(SqrL2Regularization, reg)) == 1
return applicable
end

Expand Down
11 changes: 6 additions & 5 deletions src/proximalMaps/ProxL2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ export L2Regularization
"""
L2Regularization

Regularization term implementing the proximal map for Tikhonov regularization.
Regularization term implementing the proximal map for `l_2` regularization.
"""
struct L2Regularization{T} <: AbstractParameterizedRegularization{T}
λ::T
Expand All @@ -12,11 +12,12 @@ end

"""
prox!(reg::L2Regularization, x, λ)

performs the proximal map for Tikhonov regularization.
Tikhonov
performs the proximal map for `l_2` regularization.
"""
function prox!(::L2Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
x[:] .*= 1. / (1. + 2. *λ)#*x
scale = max(0, 1 - λ / norm(x))
x[:] .*= scale
return x
end

Expand All @@ -25,4 +26,4 @@ end

returns the value of the L2-regularization term
"""
norm(::L2Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} = λ*norm(x,2)^2
norm(::L2Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} = λ*norm(x,2)
28 changes: 28 additions & 0 deletions src/proximalMaps/ProxSqrL2.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
export SqrL2Regularization

"""
SqrL2Regularization

Regularization term implementing the proximal map for Tikhonov regularization.
"""
struct SqrL2Regularization{T} <: AbstractParameterizedRegularization{T}
λ::T
SqrL2Regularization(λ::T; kargs...) where T = new{T}(λ)
end

"""
prox!(reg::SqrL2Regularization, x, λ)

performs the proximal map for Tikhonov regularization.
"""
function prox!(::SqrL2Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}}
x[:] .*= 1. / (1. + 2. *λ)#*x
return x
end

"""
norm(reg::SqrL2Regularization, x, λ)

returns the value of the L2-regularization term
"""
norm(::SqrL2Regularization, x::AbstractArray{Tc}, λ::T) where {T, Tc <: Union{T, Complex{T}}} = λ*norm(x,2)^2
1 change: 1 addition & 0 deletions src/proximalMaps/ProximalMaps.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include("ProxL1.jl")
include("ProxL2.jl")
include("ProxSqrL2.jl")
include("ProxL21.jl")
include("ProxLLR.jl")
# includes/ProxSLR.jl")
Expand Down