forked from MilesCranmer/DataDrivenDiffEq.jl
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathSR3.jl
More file actions
122 lines (92 loc) · 3.24 KB
/
SR3.jl
File metadata and controls
122 lines (92 loc) · 3.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Based upon Algorithm 2 in https://ieeexplore.ieee.org/document/8573778
"""
$(TYPEDEF)

`SR3` is an optimizer framework introduced
[by Zheng et al., 2018](https://ieeexplore.ieee.org/document/8573778) and used within
[Champion et al., 2019](https://arxiv.org/abs/1906.10612). It is configured by a
sparsification parameter `λ`, a relaxation parameter `ν` and a proximal operator.

It solves the following problem

```math
\\argmin_{x, w} \\frac{1}{2} \\| Ax-b\\|_2^2 + \\lambda R(w) + \\frac{\\nu}{2}\\|x-w\\|_2^2
```

where `R` is the regularization term handled by the proximal operator, and the result
is given by `w`.

# Fields
$(FIELDS)

# Example
```julia
opt = SR3()
opt = SR3(1e-2)
opt = SR3(1e-3, 1.0)
opt = SR3(1e-3, 1.0, SoftThreshold())
opt = SR3(1e-3, SoftThreshold())
```

## Note
Opposed to the original formulation, we use `nu` as a relaxation parameter,
as given in [Champion et al., 2019](https://arxiv.org/abs/1906.10612). In the standard
case of hard thresholding the sparsity is interpreted as `λ = threshold^2 / 2`,
otherwise `λ = threshold`. When the relaxation is omitted and only a threshold and a
proximal operator are given, `nu` is set to `one(eltype(λ))`.
"""
mutable struct SR3{T, V, P <: AbstractProximalOperator} <: AbstractSparseRegressionAlgorithm
    """Sparsity threshold"""
    thresholds::T
    """Relaxation parameter"""
    nu::V
    """Proximal operator"""
    proximal::P

    # Full constructor: threshold(s), relaxation and proximal operator, all optional.
    function SR3(threshold::T = 1e-1, nu::V = 1.0,
                 R::P = HardThreshold()) where {T, V <: Number,
                                                P <: AbstractProximalOperator}
        lower_bound = zero(eltype(threshold))
        @assert all(threshold .> lower_bound) "Threshold must be positive definite"
        @assert nu > zero(V) "Relaxation must be positive definite"
        # Hard thresholding stores threshold^2 / 2, every other operator the raw value.
        sparsity = if R isa HardThreshold
            threshold .^ 2 / 2
        else
            threshold
        end
        return new{typeof(sparsity), V, P}(sparsity, nu, R)
    end

    # Convenience constructor without a relaxation; it defaults to one(eltype(λ)).
    function SR3(threshold::T, R::P) where {T, P <: AbstractProximalOperator}
        lower_bound = zero(eltype(threshold))
        @assert all(threshold .> lower_bound) "Threshold must be positive definite"
        sparsity = if R isa HardThreshold
            threshold .^ 2 / 2
        else
            threshold
        end
        relaxation = one(eltype(sparsity))
        return new{typeof(sparsity), eltype(sparsity), P}(sparsity, relaxation, R)
    end
end
# Short, human-readable name of the algorithm.
function Base.summary(::SR3)
    return "SR3"
end
# Working storage for a single-target SR3 solve, created by `init_cache` and
# advanced in place by `step!`.
struct SR3Cache{XT, ST, PT <: AbstractProximalOperator, FT, RT, NT, DA, DB} <:
       AbstractSparseRegressionCache
    # Current coefficients; passed as first argument to the proximal operator in `step!`
    X::XT
    # Coefficients of the previous iteration (overwritten with `X` at the start of a step)
    X_prev::XT
    # Mask with the size of the coefficients, maintained by `active_set!`
    active_set::ST
    # Proximal operator taken from the algorithm
    proximal::PT
    # Buffer receiving the solution of the relaxed linear system in `step!`
    W::XT
    # Cholesky factorization of `A * A' + nu * I` built in `init_cache`
    A::FT
    # Product `B * A'` built in `init_cache`
    B::RT
    # Relaxation parameter
    nu::NT
    # Original data as handed to `init_cache`
    Ã::DA
    B̃::DB
end
# Vector targets are turned into a 1×n row matrix and forwarded to the matrix method.
init_cache(alg::SR3, A::AbstractMatrix, b::AbstractVector) = init_cache(alg, A,
                                                                        permutedims(b))
# Build the `SR3Cache` for the data `A` and a single-row target `B`.
#
# The regularized system `A * A' + nu * I` is factorized once via Cholesky, and the
# coefficients start from the solution of that system with right-hand side `B * A'`.
# The active set is initialized using the smallest threshold of `alg`.
function init_cache(alg::SR3, A::AbstractMatrix, B::AbstractMatrix)
    n_rows = size(A, 1)
    @assert size(B, 1)==1 "Caches only hold single targets!"
    smallest_threshold = minimum(get_thresholds(alg))
    nu = alg.nu
    proximal = alg.proximal

    # Init matrices
    regularized_gram = A * A' .+ I(n_rows) * nu
    factorization = cholesky(regularized_gram)
    rhs = B * A'
    coefficients = rhs / factorization

    # Mark the initially active coefficients
    mask = BitArray(undef, size(coefficients)...)
    active_set!(mask, proximal, coefficients, smallest_threshold)

    return SR3Cache{
                    typeof(coefficients),
                    typeof(mask),
                    typeof(proximal),
                    typeof(factorization),
                    typeof(rhs),
                    typeof(nu),
                    typeof(A),
                    typeof(B)
                    }(coefficients,
                      copy(coefficients),
                      mask,
                      proximal,
                      zero(coefficients),
                      factorization,
                      rhs,
                      nu,
                      A,
                      B)
end
# Perform one SR3 iteration on `cache` with sparsity parameter `λ`.
#
# The current coefficients are saved to `X_prev`, the relaxed linear system is solved
# into `W` using the stored factorization, the proximal operator is called with
# `(X, W, λ)` and the active set is refreshed. Returns `nothing`.
function step!(cache::SR3Cache, λ::T) where {T <: Number}
    coefficients = cache.X
    relaxed = cache.W
    prox = cache.proximal

    # Keep the previous iterate
    cache.X_prev .= coefficients
    # cache.A is the factorization stored by init_cache, cache.B the matching right-hand side
    relaxed .= (cache.B .+ coefficients * cache.nu) / cache.A
    # NOTE(review): presumably writes the thresholded `relaxed` into `coefficients` - confirm
    prox(coefficients, relaxed, λ)
    active_set!(cache.active_set, prox, coefficients, λ)
    return nothing
end