Skip to content

Commit ba054b7

Browse files
committed
refactor: move LM to First Order
1 parent 196adbb commit ba054b7

File tree

6 files changed

+297
-269
lines changed

6 files changed

+297
-269
lines changed

lib/NonlinearSolveBase/src/descent/geodesic_acceleration.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ function InternalAPI.solve!(
126126

127127
if 2 * norm_a norm_v * cache.α
128128
@bb @. δu = v + a / 2
129-
SciMLBase.set_du!(cache, δu, idx)
129+
set_du!(cache, δu, idx)
130130
cache.last_step_accepted = true
131131
else
132132
cache.last_step_accepted = false

lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
2020
AbstractApproximateJacobianUpdateRule, AbstractDescentDirection,
2121
AbstractApproximateJacobianUpdateRuleCache,
2222
AbstractDampingFunction, AbstractDampingFunctionCache,
23+
AbstractTrustRegionMethod, AbstractTrustRegionMethodCache,
2324
Utils, InternalAPI, get_timer_output, @static_timeit,
2425
update_trace!, L2_NORM,
2526
NewtonDescent, DampedNewtonDescent
2627
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode
2728
using SciMLOperators: AbstractSciMLOperator
2829
using Setfield: @set!
29-
using StaticArraysCore: StaticArray, Size, MArray
30+
using StaticArraysCore: StaticArray, SArray, Size, MArray
3031

3132
include("raphson.jl")
3233
include("gauss_newton.jl")
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,293 @@
"""
    LevenbergMarquardt(;
        linsolve = nothing, precs = nothing,
        damping_initial::Real = 1.0, α_geodesic::Real = 0.75, disable_geodesic = Val(false),
        damping_increase_factor::Real = 2.0, damping_decrease_factor::Real = 3.0,
        finite_diff_step_geodesic = 0.1, b_uphill::Real = 1.0, min_damping_D::Real = 1e-8,
        autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
    )

An advanced Levenberg-Marquardt implementation with the improvements suggested in
[transtrum2012improvements](@citet). Designed for large-scale and numerically-difficult
nonlinear systems.

### Keyword Arguments

  - `damping_initial`: the starting value for the damping factor. The damping factor is
    inversely proportional to the step size. The damping factor is adjusted during each
    iteration. Defaults to `1.0`. See Section 2.1 of [transtrum2012improvements](@citet).
  - `damping_increase_factor`: the factor by which the damping is increased if a step is
    rejected. Defaults to `2.0`. See Section 2.1 of [transtrum2012improvements](@citet).
  - `damping_decrease_factor`: the factor by which the damping is decreased if a step is
    accepted. Defaults to `3.0`. See Section 2.1 of [transtrum2012improvements](@citet).
  - `min_damping_D`: the minimum value of the damping terms in the diagonal damping matrix
    `DᵀD`, where `DᵀD` is given by the largest diagonal entries of `JᵀJ` yet encountered,
    where `J` is the Jacobian. It is suggested by [transtrum2012improvements](@citet) to use
    a minimum value of the elements in `DᵀD` to prevent the damping from being too small.
    Defaults to `1e-8`.
  - `disable_geodesic`: Disables Geodesic Acceleration if set to `Val(true)` (a plain
    `true` is treated the same way). Any other value, including `Val(false)` and `false`,
    keeps Geodesic Acceleration enabled. It provides a way to trade-off robustness for
    speed, though in most situations Geodesic Acceleration should not be disabled.

For the remaining arguments, see [`GeodesicAcceleration`](@ref) and
[`NonlinearSolve.LevenbergMarquardtTrustRegion`](@ref) documentations.
"""
function LevenbergMarquardt(;
        linsolve = nothing, precs = nothing,
        damping_initial::Real = 1.0, α_geodesic::Real = 0.75, disable_geodesic = Val(false),
        damping_increase_factor::Real = 2.0, damping_decrease_factor::Real = 3.0,
        finite_diff_step_geodesic = 0.1, b_uphill::Real = 1.0, min_damping_D::Real = 1e-8,
        autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
)
    descent = DampedNewtonDescent(;
        linsolve,
        precs,
        initial_damping = damping_initial,
        damping_fn = LevenbergMarquardtDampingFunction(
            damping_increase_factor, damping_decrease_factor, min_damping_D
        )
    )
    # Geodesic acceleration is the documented default and must stay on unless it is
    # explicitly disabled. Testing `isa Val{false}` used to switch it off for every other
    # input, including a plain `false`, which is the opposite of what the caller asked for.
    geodesic_disabled = disable_geodesic isa Val{true} || disable_geodesic === true
    if !geodesic_disabled
        descent = GeodesicAcceleration(descent, finite_diff_step_geodesic, α_geodesic)
    end
    trustregion = LevenbergMarquardtTrustRegion(b_uphill)
    return GeneralizedFirstOrderAlgorithm(;
        trustregion,
        descent,
        autodiff,
        vjp_autodiff,
        jvp_autodiff,
        name = :LevenbergMarquardt
    )
end
63+
64+
# Configuration for the Levenberg-Marquardt damping rule. The three fields are stored
# exactly as passed by the constructor and are only converted to the working element
# type when a cache is built from them in `InternalAPI.init`.
@concrete struct LevenbergMarquardtDampingFunction <: AbstractDampingFunction
    # Multiplier applied to the damping factor `λ` after a rejected step.
    increase_factor
    # Divisor applied to the damping factor `λ` after an accepted step.
    decrease_factor
    # Initial (and reset) value for every entry of the diagonal scaling `DᵀD`.
    min_damping
end
69+
70+
# Build the mutable damping cache for a `LevenbergMarquardtDampingFunction`.
#
# All scalar settings are converted to `T`, the promoted element type of `u` and `fu`.
# `DᵀD` starts out filled with `min_damping`. A scratch buffer shaped like `u` is only
# allocated when the damping is computed from the plain Jacobian (`normal_form` is not
# `Val{true}`); otherwise the slot holds `nothing`. `prob`, `J` and `kwargs` are accepted
# but not read here.
function InternalAPI.init(
        prob::AbstractNonlinearProblem, f::LevenbergMarquardtDampingFunction,
        initial_damping, J, fu, u, normal_form::Val; kwargs...
)
    T = promote_type(eltype(u), eltype(fu))
    λ₀ = T(initial_damping)
    grow = T(f.increase_factor)
    damping_floor = T(f.min_damping)

    DᵀD = init_levenberg_marquardt_diagonal(u, damping_floor)
    if normal_form isa Val{true}
        J_diag_cache = nothing
    else
        @bb J_diag_cache = similar(u)
    end

    return LevenbergMarquardtDampingCache(
        grow, T(f.decrease_factor), damping_floor, grow, λ₀, DᵀD, J_diag_cache,
        λ₀ .* DᵀD, f, λ₀
    )
end
88+
89+
# Mutable state of the Levenberg-Marquardt damping rule during a solve.
@concrete mutable struct LevenbergMarquardtDampingCache <: AbstractDampingFunctionCache
    # Factor written into `λ_factor` after every damping update (the "reject" default).
    increase_factor
    # After a fully accepted step `λ_factor` becomes `1 / decrease_factor`.
    decrease_factor
    # Value used to (re)initialise the entries of `DᵀD`.
    min_damping
    # Multiplier that will be applied to `λ` at the next `callback_into_cache!`.
    λ_factor
    # Current damping factor.
    λ
    # Running element-wise maximum used as diagonal scaling (a number for scalar problems,
    # otherwise a `Diagonal`).
    DᵀD
    # Scratch storage for the squared column norms of the Jacobian; `nothing` when the
    # cache was initialised for the normal form.
    J_diag_cache
    # The damping term handed back to the caller, `λ * DᵀD`.
    J_damped
    # The `LevenbergMarquardtDampingFunction` this cache was created from.
    damping_f
    # Starting value of `λ`, restored by `reinit!`.
    initial_damping
end
101+
102+
# Reset the damping cache to the state produced by `InternalAPI.init`.
#
# `λ` goes back to `initial_damping`, `λ_factor` to the configured increase factor, and
# `DᵀD` is refilled with `min_damping`. The scalar case is reset as well: it used to be
# skipped, so a re-used cache for a scalar problem kept the running maximum accumulated
# by the previous solve. Extra positional and keyword arguments are ignored.
# Returns `nothing`.
function InternalAPI.reinit!(cache::LevenbergMarquardtDampingCache, args...; kwargs...)
    cache.λ = cache.initial_damping
    cache.λ_factor = cache.damping_f.increase_factor
    scaling = cache.DᵀD
    if scaling isa Number
        cache.DᵀD = oftype(scaling, cache.min_damping)
    elseif ArrayInterface.can_setindex(scaling.diag)
        scaling.diag .= cache.min_damping
    else
        # Immutable diagonal storage: rebuild it instead of writing in place.
        cache.DᵀD = Diagonal(ones(typeof(scaling.diag)) * cache.min_damping)
    end
    cache.J_damped = cache.λ .* cache.DᵀD
    return
end
115+
116+
# Trait: neither the damping function nor its cache asks for the Jacobian in normal form.
NonlinearSolveBase.requires_normal_form_jacobian(
    ::Union{LevenbergMarquardtDampingFunction, LevenbergMarquardtDampingCache}
) = false
120+
# Trait: neither the damping function nor its cache asks for the right-hand side in
# normal form.
NonlinearSolveBase.requires_normal_form_rhs(
    ::Union{LevenbergMarquardtDampingFunction, LevenbergMarquardtDampingCache}
) = false
124+
# Trait: the damping term produced here (`λ * DᵀD`) is reported as normal-form damping.
NonlinearSolveBase.returns_norm_form_damping(
    ::Union{LevenbergMarquardtDampingFunction, LevenbergMarquardtDampingCache}
) = true
128+
129+
# Calling the cache with `nothing` hands back the most recently computed damping term.
function (damping::LevenbergMarquardtDampingCache)(::Nothing)
    return damping.J_damped
end
130+
131+
# Refresh the damping term from a plain (non-normal-form) Jacobian `J`.
#
# The squared column norms of `J` (the diagonal of `JᵀJ`) are written to `J_diag_cache`,
# folded into the running maximum `DᵀD`, and `J_damped = λ * DᵀD` is recomputed and
# returned. `fu` and `kwargs` are not read.
function InternalAPI.solve!(
        cache::LevenbergMarquardtDampingCache, J, fu, ::Val{false}; kwargs...
)
    if ArrayInterface.can_setindex(cache.J_diag_cache)
        # Reducing `J'` into a vector sums each row of `J'`, i.e. each column of `J`.
        sum!(abs2, Utils.safe_vec(cache.J_diag_cache), J')
    elseif cache.J_diag_cache isa Number
        cache.J_diag_cache = abs2(J)
    else
        # Out-of-place fallback for immutable storage. It must reduce along the second
        # dimension of `J'` to produce the same column norms of `J` as the `sum!` branch;
        # reducing along `dims = 1` yields the row norms of `J` (diag of `JJᵀ`), which
        # is the wrong quantity and has the wrong length for non-square Jacobians.
        cache.J_diag_cache = dropdims(sum(abs2, J'; dims = 2); dims = 2)
    end
    cache.DᵀD = update_levenberg_marquardt_diagonal!!(
        cache.DᵀD, Utils.safe_vec(cache.J_diag_cache)
    )
    @bb @. cache.J_damped = cache.λ * cache.DᵀD
    return cache.J_damped
end
147+
148+
# Refresh the damping term when the caller already supplies the normal-form matrix `JᵀJ`.
#
# The running maximum `DᵀD` is updated from `JᵀJ`, then `J_damped = λ * DᵀD` is
# recomputed and returned. `fu` and `kwargs` are not read.
function InternalAPI.solve!(
        cache::LevenbergMarquardtDampingCache, JᵀJ, fu, ::Val{true}; kwargs...
)
    running_max = update_levenberg_marquardt_diagonal!!(cache.DᵀD, JᵀJ)
    cache.DᵀD = running_max
    @bb @. cache.J_damped = cache.λ * cache.DᵀD
    return cache.J_damped
end
155+
156+
# Adapt the damping factor `λ` after an iteration of the outer solver.
#
# When both the trust-region cache and the descent cache report an accepted step, `λ` is
# scaled by `1 / decrease_factor`; otherwise it is scaled by the pending `λ_factor`.
# Afterwards `λ_factor` is re-armed with `increase_factor`. Additional positional
# arguments are ignored.
function NonlinearSolveBase.callback_into_cache!(
        topcache, cache::LevenbergMarquardtDampingCache, args...
)
    step_accepted = NonlinearSolveBase.last_step_accepted(topcache.trustregion_cache) &&
                    NonlinearSolveBase.last_step_accepted(topcache.descent_cache)
    if step_accepted
        cache.λ_factor = 1 / cache.decrease_factor
    end
    cache.λ *= cache.λ_factor
    return cache.λ_factor = cache.increase_factor
end
166+
"""
    LevenbergMarquardtTrustRegion(b_uphill)

Trust Region method for [`LevenbergMarquardt`](@ref). Rather than maintaining a trust
region radius, this method is tightly coupled with the Levenberg-Marquardt method and acts
by directly updating the damping parameter.

### Arguments

  - `b_uphill`: a factor that determines if a step is accepted or rejected. The standard
    Levenberg-Marquardt choice accepts every step that decreases the cost and rejects
    every step that increases it. That is natural and safe, but often not the most
    efficient. Here downhill moves are always accepted, while uphill moves are accepted
    conditionally. At each iteration ``i`` we compute
    ``\\beta_i = \\cos(v_{\\text{new}}, v_{\\text{old}})``, the cosine of the angle between
    the proposed velocity ``v_{\\text{new}}`` and the velocity of the last accepted step
    ``v_{\\text{old}}``. Uphill moves are accepted when that angle is small, namely when
    ``(1-\\beta_i)^{b_{\\text{uphill}}} C_{i+1} \\le C_i``, where ``C_i`` is the cost at
    iteration ``i``. Reasonable choices for `b_uphill` are `1.0` or `2.0`, with
    `b_uphill = 2.0` allowing higher uphill moves than `b_uphill = 1.0`. When
    `b_uphill = 0.0`, no uphill moves will be accepted. Defaults to `1.0`. See Section 4 of
    [transtrum2012improvements](@citet).
"""
@concrete struct LevenbergMarquardtTrustRegion <: AbstractTrustRegionMethod
    # Exponent `b_uphill` from the acceptance rule above.
    β_uphill
end
195+
196+
# Create the trust-region cache for `LevenbergMarquardtTrustRegion`.
#
# The previous loss and previous velocity norm both start at `Inf` (in the promoted
# element type of `u` and `fu`), the velocity buffer starts as a copy of `u`, and scratch
# buffers shaped like `u` and `fu` are allocated for the trial point and its residual.
# `prob`, the extra positional `args` and remaining `kwargs` are not read.
function InternalAPI.init(
        prob::AbstractNonlinearProblem, alg::LevenbergMarquardtTrustRegion,
        f::NonlinearFunction, fu, u, p, args...;
        stats, internalnorm::F = L2_NORM, kwargs...
) where {F}
    T = promote_type(eltype(u), eltype(fu))
    unbounded = T(Inf)
    @bb fu_cache = similar(fu)
    @bb u_cache = similar(u)
    @bb v_cache = copy(u)
    return LevenbergMarquardtTrustRegionCache(
        f, p, unbounded, v_cache, unbounded, internalnorm, T(alg.β_uphill), false,
        u_cache, fu_cache, stats
    )
end
210+
211+
# Mutable state of the Levenberg-Marquardt step-acceptance test.
@concrete mutable struct LevenbergMarquardtTrustRegionCache <:
                         AbstractTrustRegionMethodCache
    # Residual function evaluated at trial points.
    f
    # Parameters passed to `f`.
    p
    # Reference loss that a trial step is compared against (initialised to `Inf`).
    loss_old
    # Velocity of the last accepted step.
    v_cache
    # Norm of the velocity of the last accepted step (initialised to `Inf`).
    norm_v_old
    # Norm used both for velocities and for the residual loss.
    internalnorm
    # Exponent of the uphill acceptance rule.
    β_uphill
    # Outcome of the most recent acceptance test.
    last_step_accepted::Bool
    # Trial point `u + δu`.
    u_cache
    # Residual evaluated at the trial point.
    fu_cache
    # Solver statistics; the function-evaluation counter is bumped on every trial.
    stats::NLStats
end
225+
226+
# Reset the trust-region cache for a fresh solve.
#
# Stores the (possibly new) parameters `p`, copies `u0` into the velocity buffer, puts
# the reference loss and velocity norm back to `Inf`, and clears the acceptance flag.
# Other keyword arguments are ignored.
function InternalAPI.reinit!(
        cache::LevenbergMarquardtTrustRegionCache; p = cache.p, u0 = cache.v_cache, kwargs...
)
    @bb copyto!(cache.v_cache, u0)
    cache.p = p
    cache.norm_v_old = oftype(cache.norm_v_old, Inf)
    cache.loss_old = oftype(cache.loss_old, Inf)
    return cache.last_step_accepted = false
end
235+
236+
# Evaluate a proposed Levenberg-Marquardt step and decide whether to accept it.
#
# Arguments, as used below:
#   - `cache`: supplies `f`, `p`, the norm, the exponent `β_uphill`, and the buffers.
#   - `J`, `fu`: part of the call signature; not read here.
#   - `u`, `δu`: current iterate and proposed step; the trial point is `u + δu`.
#   - `descent_stats`: if it has a field `v`, that is taken as the velocity of the
#     step, otherwise `δu` itself is used.
#
# The step is accepted when `(1 - β)^β_uphill * loss ≤ loss_old`, where `β` is the cosine
# between the new velocity and the last accepted one and `loss` is the norm of the
# residual at the trial point. On acceptance the stored velocity and its norm are updated.
#
# Returns `(accepted, u_cache, fu_cache)`.
function InternalAPI.solve!(
        cache::LevenbergMarquardtTrustRegionCache, J, fu, u, δu, descent_stats
)
    # This should be true if Geodesic Acceleration is being used
    v = hasfield(typeof(descent_stats), :v) ? descent_stats.v : δu
    norm_v = cache.internalnorm(v)
    β = dot(v, cache.v_cache) / (norm_v * cache.norm_v_old)

    @bb @. cache.u_cache = u + δu
    cache.fu_cache = Utils.evaluate_f!!(cache.f, cache.fu_cache, cache.u_cache, cache.p)
    cache.stats.nf += 1

    loss = cache.internalnorm(cache.fu_cache)

    # NOTE(review): nothing in this method assigns `cache.loss_old`, so unless it is
    # updated elsewhere the comparison is always against `Inf` — confirm whether
    # `loss_old` should be set to `loss` when a step is accepted.
    if (1 - β)^cache.β_uphill * loss ≤ cache.loss_old  # Accept Step
        cache.last_step_accepted = true
        cache.norm_v_old = norm_v
        @bb copyto!(cache.v_cache, v)
    else
        cache.last_step_accepted = false
    end

    return cache.last_step_accepted, cache.u_cache, cache.fu_cache
end
260+
261+
# Scalar case of the running-maximum update: keep the larger of the stored value `y` and
# the new value `x`.
function update_levenberg_marquardt_diagonal!!(y::Number, x::Number)
    return max(y, x)
end
262+
# Running-maximum update of the diagonal scaling `y`.
#
# `x` is either a vector of candidate diagonal entries or a matrix whose diagonal holds
# them. Each entry of `y.diag` becomes the maximum of its current value and the matching
# candidate. Mutable storage is updated in place (with an explicit loop when scalar
# indexing is fast, with broadcasting otherwise) and `y` itself is returned; immutable
# storage yields a freshly built `Diagonal`.
function update_levenberg_marquardt_diagonal!!(y::Diagonal, x::AbstractVecOrMat)
    x_is_vector = ndims(x) == 1

    if !ArrayInterface.can_setindex(y.diag)
        x_is_vector && return Diagonal(max.(y.diag, x))
        return Diagonal(max.(y.diag, @view(x[diagind(x)])))
    end

    if !ArrayInterface.fast_scalar_indexing(y.diag)
        if x_is_vector
            @. y.diag = max(y.diag, x)
        else
            y.diag .= max.(y.diag, @view(x[diagind(x)]))
        end
        return y
    end

    if x_is_vector
        @simd ivdep for i in axes(x, 1)
            @inbounds y.diag[i] = max(y.diag[i], x[i])
        end
    else
        @simd ivdep for i in axes(x, 1)
            @inbounds y.diag[i] = max(y.diag[i], x[i, i])
        end
    end
    return y
end
286+
287+
# Scalar problems: the "diagonal" is just `v` converted to the type of `u`.
function init_levenberg_marquardt_diagonal(u::Number, v)
    return oftype(u, v)
end
288+
# Static arrays cannot be filled in place, so build the constant diagonal out of place
# from a vector of ones with the flattened type of `u`, scaled by `v`.
function init_levenberg_marquardt_diagonal(u::SArray, v)
    return Diagonal(ones(typeof(vec(u))) * v)
end
289+
# Generic arrays: allocate storage shaped like the flattened `u`, fill every entry with
# `v`, and wrap it in a `Diagonal`.
function init_levenberg_marquardt_diagonal(u, v)
    entries = broadcast!(identity, similar(vec(u)), v)
    return Diagonal(entries)
end

lib/NonlinearSolveFirstOrder/src/solve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ function SciMLBase.__init(
184184
NonlinearSolveBase.supports_trust_region(alg.descent) ||
185185
error("Trust Region not supported by $(alg.descent).")
186186
trustregion_cache = InternalAPI.init(
187-
prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...
187+
prob, alg.trustregion, prob.f, fu, u, prob.p;
188+
stats, internalnorm, kwargs...
188189
)
189190
globalization = Val(:TrustRegion)
190191
end

0 commit comments

Comments
 (0)