Skip to content

Commit 6e2f58d

Browse files
committed
Allow Levenberg to work with NonlinearLeastSquaresProblem
1 parent 30f60a5 commit 6e2f58d

File tree

6 files changed

+54
-42
lines changed

6 files changed

+54
-42
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ NonlinearProblemLibrary = "0.1"
3939
PrecompileTools = "1"
4040
RecursiveArrayTools = "2"
4141
Reexport = "0.2, 1"
42-
SciMLBase = "1.97, 2"
42+
SciMLBase = "2"
4343
SimpleNonlinearSolve = "0.1"
4444
SparseDiffTools = "2.6"
4545
StaticArraysCore = "1.4"

src/NonlinearSolve.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,19 @@ abstract type AbstractNonlinearSolveCache{iip} end
3232

3333
isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
3434

35-
function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm,
36-
args...; kwargs...)
35+
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
36+
alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
3737
cache = init(prob, alg, args...; kwargs...)
3838
return solve!(cache)
3939
end
4040

41+
function not_terminated(cache::AbstractNonlinearSolveCache)
42+
return !cache.force_stop && cache.stats.nsteps < cache.maxiters
43+
end
44+
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu1
45+
4146
function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
42-
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
47+
while not_terminated(cache)
4348
perform_step!(cache)
4449
cache.stats.nsteps += 1
4550
end
@@ -50,7 +55,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
5055
cache.retcode = ReturnCode.Success
5156
end
5257

53-
return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
58+
return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, get_fu(cache);
5459
cache.retcode, cache.stats)
5560
end
5661

src/levenberg.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ numerically-difficult nonlinear systems.
7272
where `J` is the Jacobian. It is suggested by
7373
[this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in
7474
`DᵀD` to prevent the damping from being too small. Defaults to `1e-8`.
75+
76+
!!! warning
77+
78+
`linsolve` and `precs` are used exclusively for the inplace version of the algorithm.
79+
Support for the out-of-place (OOP) version is planned!
7580
"""
7681
@concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD}
7782
ad::AD
@@ -135,11 +140,14 @@ end
135140
loss_old::lossType
136141
make_new_J::Bool
137142
fu_tmp
143+
u_tmp
144+
Jv
138145
mat_tmp::jType
139146
stats::NLStats
140147
end
141148

142-
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarquardt,
149+
function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
150+
NonlinearLeastSquaresProblem{uType, iip}}, alg::LevenbergMarquardt,
143151
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
144152
linsolve_kwargs = (;), kwargs...) where {uType, iip}
145153
@unpack f, u0, p = prob
@@ -166,21 +174,21 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarq
166174
end
167175

168176
loss = internalnorm(fu1)
169-
JᵀJ = zero(J)
177+
JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
170178
v = zero(u)
171179
a = zero(u)
172180
tmp_vec = zero(u)
173181
v_old = zero(u)
174182
δ = zero(u)
175183
make_new_J = true
176184
fu_tmp = zero(fu1)
177-
mat_tmp = zero(J)
185+
mat_tmp = zero(JᵀJ)
178186

179187
return LevenbergMarquardtCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
180188
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, DᵀD,
181189
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
182190
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
183-
mat_tmp, NLStats(1, 0, 0, 0, 0))
191+
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0))
184192
end
185193

186194
function perform_step!(cache::LevenbergMarquardtCache{true})
@@ -200,10 +208,10 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
200208
@unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
201209

202210
# Usual Levenberg-Marquardt step ("velocity").
203-
# The following lines do: cache.v = -cache.mat_tmp \ cache.fu_tmp
204-
mul!(cache.fu_tmp, J', fu1)
211+
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
212+
mul!(cache.u_tmp, J', fu1)
205213
@. cache.mat_tmp = JᵀJ + λ * DᵀD
206-
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.fu_tmp),
214+
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp),
207215
linu = _vec(cache.du), p = p, reltol = cache.abstol)
208216
cache.linsolve = linres.cache
209217
@. cache.v = -cache.du
@@ -213,8 +221,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
213221
f(cache.fu_tmp, u .+ h .* v, p)
214222

215223
# The following lines do: cache.a = -cache.mat_tmp \ cache.fu_tmp
216-
mul!(cache.du, J, v)
217-
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.du)
224+
mul!(cache.Jv, J, v)
225+
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
218226
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.fu_tmp),
219227
linu = _vec(cache.du), p = p, reltol = cache.abstol)
220228
cache.linsolve = linres.cache

src/raphson.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -127,18 +127,22 @@ function perform_step!(cache::NewtonRaphsonCache{false})
127127
return nothing
128128
end
129129

130-
function SciMLBase.solve!(cache::NewtonRaphsonCache)
131-
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
132-
perform_step!(cache)
133-
cache.stats.nsteps += 1
134-
end
135-
136-
if cache.stats.nsteps == cache.maxiters
137-
cache.retcode = ReturnCode.MaxIters
130+
function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p,
131+
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
132+
cache.p = p
133+
if iip
134+
recursivecopy!(cache.u, u0)
135+
cache.f(cache.fu1, cache.u, p)
138136
else
139-
cache.retcode = ReturnCode.Success
137+
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
138+
cache.u = u0
139+
cache.fu1 = cache.f(cache.u, p)
140140
end
141-
142-
return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
143-
cache.retcode, cache.stats)
141+
cache.abstol = abstol
142+
cache.maxiters = maxiters
143+
cache.stats.nf = 1
144+
cache.stats.nsteps = 1
145+
cache.force_stop = false
146+
cache.retcode = ReturnCode.Default
147+
return cache
144148
end

src/trustRegion.jl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ for large-scale and numerically-difficult nonlinear systems.
141141
`expand_threshold < r` (with `r` defined in `shrink_threshold`). Defaults to `2.0`.
142142
- `max_shrink_times`: the maximum number of times to shrink the trust region radius in a
143143
row. If `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`.
144+
145+
!!! warning
146+
147+
`linsolve` and `precs` are used exclusively for the inplace version of the algorithm.
148+
Support for the out-of-place (OOP) version is planned!
144149
"""
145150
@concrete struct TrustRegion{CJ, AD, MTR} <: AbstractNewtonAlgorithm{CJ, AD}
146151
ad::AD
@@ -662,22 +667,11 @@ function jvp!(cache::TrustRegionCache{true})
662667
return g
663668
end
664669

665-
function SciMLBase.solve!(cache::TrustRegionCache)
666-
while !cache.force_stop && cache.stats.nsteps < cache.maxiters &&
667-
cache.shrink_counter < cache.alg.max_shrink_times
668-
perform_step!(cache)
669-
cache.stats.nsteps += 1
670-
end
671-
672-
if cache.stats.nsteps == cache.maxiters
673-
cache.retcode = ReturnCode.MaxIters
674-
else
675-
cache.retcode = ReturnCode.Success
676-
end
677-
678-
return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu; cache.retcode,
679-
cache.stats)
670+
function not_terminated(cache::TrustRegionCache)
671+
return !cache.force_stop && cache.stats.nsteps < cache.maxiters &&
672+
cache.shrink_counter < cache.alg.max_shrink_times
680673
end
674+
get_fu(cache::TrustRegionCache) = cache.fu
681675

682676
function SciMLBase.reinit!(cache::TrustRegionCache{iip}, u0 = cache.u; p = cache.p,
683677
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}

src/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ _maybe_mutable(x, ::AutoSparseEnzyme) = _mutable(x)
144144
_maybe_mutable(x, _) = x
145145

146146
# Helper function to get value of `f(u, p)`
147-
function evaluate_f(prob::NonlinearProblem{uType, iip}, u) where {uType, iip}
147+
function evaluate_f(prob::Union{NonlinearProblem{uType, iip},
148+
NonlinearLeastSquaresProblem{uType, iip}}, u) where {uType, iip}
148149
@unpack f, u0, p = prob
149150
if iip
150151
fu = f.resid_prototype === nothing ? zero(u) : f.resid_prototype

0 commit comments

Comments
 (0)