Skip to content

Commit c6c8ce9

Browse files
author
oscarddssmith
committed
use NLStats
1 parent 753173e commit c6c8ce9

File tree

6 files changed

+67
-46
lines changed

6 files changed

+67
-46
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1111
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
14+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1415
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1516
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1617
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1718
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
18-
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1919
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2020
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
2121
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
@@ -28,11 +28,11 @@ EnumX = "1"
2828
FiniteDiff = "2"
2929
ForwardDiff = "0.10.3"
3030
LinearSolve = "2"
31+
PrecompileTools = "1"
3132
RecursiveArrayTools = "2"
3233
Reexport = "0.2, 1"
33-
SciMLBase = "1.73"
34+
SciMLBase = "1.92.4"
3435
SimpleNonlinearSolve = "0.1"
35-
PrecompileTools = "1"
3636
SparseDiffTools = "1, 2"
3737
StaticArraysCore = "1.4"
3838
UnPack = "1.0"

src/NonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using DiffEqBase
1717
using SparseDiffTools
1818

1919
@reexport using SciMLBase
20+
using SciMLBase: NLStats
2021
@reexport using SimpleNonlinearSolve
2122

2223
import SciMLBase: _unwrap_val

src/levenberg.jl

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ mutable struct LevenbergMarquardtCache{iip, fType, algType, uType, duType, resTy
158158
J::jType
159159
du_tmp::duType
160160
jac_config::JC
161-
iter::Int
162161
force_stop::Bool
163162
maxiters::Int
164163
internalnorm::INType
@@ -185,10 +184,11 @@ mutable struct LevenbergMarquardtCache{iip, fType, algType, uType, duType, resTy
185184
make_new_J::Bool
186185
fu_tmp::resType
187186
mat_tmp::jType
187+
stats::NLStats
188188

189189
function LevenbergMarquardtCache{iip}(f::fType, alg::algType, u::uType, fu::resType,
190190
p::pType, uf::ufType, linsolve::L, J::jType,
191-
du_tmp::duType, jac_config::JC, iter::Int,
191+
du_tmp::duType, jac_config::JC,
192192
force_stop::Bool, maxiters::Int,
193193
internalnorm::INType,
194194
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
@@ -202,7 +202,7 @@ mutable struct LevenbergMarquardtCache{iip, fType, algType, uType, duType, resTy
202202
norm_v_old::lossType, δ::uType,
203203
loss_old::lossType, make_new_J::Bool,
204204
fu_tmp::resType,
205-
mat_tmp::jType) where {
205+
mat_tmp::jType, stats::NLStats) where {
206206
iip, fType, algType,
207207
uType, duType, resType,
208208
pType, INType, tolType,
@@ -213,15 +213,15 @@ mutable struct LevenbergMarquardtCache{iip, fType, algType, uType, duType, resTy
213213
new{iip, fType, algType, uType, duType, resType,
214214
pType, INType, tolType, probType, ufType, L,
215215
jType, JC, DᵀDType, λType, lossType}(f, alg, u, fu, p, uf, linsolve, J, du_tmp,
216-
jac_config, iter, force_stop, maxiters,
216+
jac_config, force_stop, maxiters,
217217
internalnorm, retcode, abstol, prob, DᵀD,
218218
JᵀJ, λ, λ_factor,
219219
damping_increase_factor,
220220
damping_decrease_factor, h,
221221
α_geodesic, b_uphill, min_damping_D,
222222
v, a, tmp_vec, v_old,
223223
norm_v_old, δ, loss_old, make_new_J,
224-
fu_tmp, mat_tmp)
224+
fu_tmp, mat_tmp, stats)
225225
end
226226
end
227227

@@ -301,14 +301,13 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarq
301301
mat_tmp = zero(J)
302302

303303
return LevenbergMarquardtCache{iip}(f, alg, u, fu, p, uf, linsolve, J, du_tmp,
304-
jac_config,
305-
1, false, maxiters, internalnorm,
304+
jac_config, false, maxiters, internalnorm,
306305
ReturnCode.Default, abstol, prob, DᵀD, JᵀJ,
307306
λ, λ_factor, damping_increase_factor,
308307
damping_decrease_factor, h,
309308
α_geodesic, b_uphill, min_damping_D,
310309
v, a, tmp_vec, v_old, loss, δ, loss, make_new_J,
311-
fu_tmp, mat_tmp)
310+
fu_tmp, mat_tmp, NLStats(1,0,0,0,0))
312311
end
313312
function perform_step!(cache::LevenbergMarquardtCache{true})
314313
@unpack fu, f, make_new_J = cache
@@ -321,6 +320,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
321320
mul!(cache.JᵀJ, cache.J', cache.J)
322321
cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
323322
cache.make_new_J = false
323+
cache.stats.njacs += 1
324324
end
325325
@unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
326326

@@ -344,13 +344,16 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
344344
linu = _vec(cache.du_tmp), p = p, reltol = cache.abstol)
345345
cache.linsolve = linres.cache
346346
@. cache.a = -cache.du_tmp
347+
cache.stats.nsolve += 2
348+
cache.stats.nfactors += 2
347349

348350
# Require acceptable steps to satisfy the following condition.
349351
norm_v = norm(v)
350352
if (2 * norm(cache.a) / norm_v) < α_geodesic
351353
@. cache.δ = v + cache.a / 2
352354
@unpack δ, loss_old, norm_v_old, v_old, b_uphill = cache
353355
f(cache.fu_tmp, u .+ δ, p)
356+
cache.stats.nf += 1
354357
loss = cache.internalnorm(cache.fu_tmp)
355358

356359
# Condition to accept uphill steps (evaluates to `loss ≤ loss_old` in iteration 1).
@@ -390,6 +393,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
390393
cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
391394
end
392395
cache.make_new_J = false
396+
cache.stats.njacs += 1
393397
end
394398
@unpack u, p, λ, JᵀJ, DᵀD, J = cache
395399

@@ -399,13 +403,16 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
399403
@unpack v, h, α_geodesic = cache
400404
# Geodesic acceleration (step_size = v + a / 2).
401405
cache.a = -J \ ((2 / h) .* ((f(u .+ h .* v, p) .- fu) ./ h .- J * v))
406+
cache.stats.nsolve += 1
407+
cache.stats.nfactors += 1
402408

403409
# Require acceptable steps to satisfy the following condition.
404410
norm_v = norm(v)
405411
if (2 * norm(cache.a) / norm_v) < α_geodesic
406412
cache.δ = v .+ cache.a ./ 2
407413
@unpack δ, loss_old, norm_v_old, v_old, b_uphill = cache
408414
fu_new = f(u .+ δ, p)
415+
cache.stats.nf += 1
409416
loss = cache.internalnorm(fu_new)
410417

411418
# Condition to accept uphill steps (evaluates to `loss ≤ loss_old` in iteration 1).
@@ -431,17 +438,17 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
431438
end
432439

433440
function SciMLBase.solve!(cache::LevenbergMarquardtCache)
434-
while !cache.force_stop && cache.iter < cache.maxiters
441+
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
435442
perform_step!(cache)
436-
cache.iter += 1
443+
cache.stats.nsteps += 1
437444
end
438445

439-
if cache.iter == cache.maxiters
446+
if cache.stats.nsteps == cache.maxiters
440447
cache.retcode = ReturnCode.MaxIters
441448
else
442449
cache.retcode = ReturnCode.Success
443450
end
444451

445452
SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu;
446-
retcode = cache.retcode)
453+
retcode = cache.retcode, stats = cache.stats)
447454
end

src/raphson.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,30 +75,28 @@ mutable struct NewtonRaphsonCache{iip, fType, algType, uType, duType, resType, p
7575
J::jType
7676
du1::duType
7777
jac_config::JC
78-
iter::Int
7978
force_stop::Bool
8079
maxiters::Int
8180
internalnorm::INType
8281
retcode::SciMLBase.ReturnCode.T
8382
abstol::tolType
8483
prob::probType
84+
stats::NLStats
8585

8686
function NewtonRaphsonCache{iip}(f::fType, alg::algType, u::uType, fu::resType,
87-
p::pType,
88-
uf::ufType, linsolve::L, J::jType, du1::duType,
89-
jac_config::JC, iter::Int,
90-
force_stop::Bool, maxiters::Int, internalnorm::INType,
87+
p::pType, uf::ufType, linsolve::L, J::jType, du1::duType,
88+
jac_config::JC, force_stop::Bool, maxiters::Int, internalnorm::INType,
9189
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
92-
prob::probType) where {
90+
prob::probType, stats::NLStats) where {
9391
iip, fType, algType, uType,
9492
duType, resType, pType, INType,
9593
tolType,
9694
probType, ufType, L, jType, JC}
9795
new{iip, fType, algType, uType, duType, resType, pType, INType, tolType,
9896
probType, ufType, L, jType, JC}(f, alg, u, fu, p,
99-
uf, linsolve, J, du1, jac_config, iter,
97+
uf, linsolve, J, du1, jac_config,
10098
force_stop, maxiters, internalnorm,
101-
retcode, abstol, prob)
99+
retcode, abstol, prob, stats)
102100
end
103101
end
104102

@@ -150,8 +148,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson
150148
uf, linsolve, J, du1, jac_config = jacobian_caches(alg, f, u, p, Val(iip))
151149

152150
return NewtonRaphsonCache{iip}(f, alg, u, fu, p, uf, linsolve, J, du1, jac_config,
153-
1, false, maxiters, internalnorm,
154-
ReturnCode.Default, abstol, prob)
151+
false, maxiters, internalnorm,
152+
ReturnCode.Default, abstol, prob, NLStats(1,0,0,0,0))
155153
end
156154

157155
function perform_step!(cache::NewtonRaphsonCache{true})
@@ -169,6 +167,10 @@ function perform_step!(cache::NewtonRaphsonCache{true})
169167
if cache.internalnorm(cache.fu) < cache.abstol
170168
cache.force_stop = true
171169
end
170+
cache.stats.nf += 1
171+
cache.stats.njacs += 1
172+
cache.stats.nsolve += 1
173+
cache.stats.nfactors += 1
172174
return nothing
173175
end
174176

@@ -180,23 +182,27 @@ function perform_step!(cache::NewtonRaphsonCache{false})
180182
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
181183
cache.force_stop = true
182184
end
185+
cache.stats.nf += 1
186+
cache.stats.njacs += 1
187+
cache.stats.nsolve += 1
188+
cache.stats.nfactors += 1
183189
return nothing
184190
end
185191

186192
function SciMLBase.solve!(cache::NewtonRaphsonCache)
187-
while !cache.force_stop && cache.iter < cache.maxiters
193+
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
188194
perform_step!(cache)
189-
cache.iter += 1
195+
cache.stats.nsteps += 1
190196
end
191197

192-
if cache.iter == cache.maxiters
198+
if cache.stats.nsteps == cache.maxiters
193199
cache.retcode = ReturnCode.MaxIters
194200
else
195201
cache.retcode = ReturnCode.Success
196202
end
197203

198204
SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu;
199-
retcode = cache.retcode)
205+
retcode = cache.retcode, stats = cache.stats)
200206
end
201207

202208
function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p,
@@ -212,7 +218,8 @@ function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cac
212218
end
213219
cache.abstol = abstol
214220
cache.maxiters = maxiters
215-
cache.iter = 1
221+
cache.stats.nf = 1
222+
cache.stats.nsteps = 1
216223
cache.force_stop = false
217224
cache.retcode = ReturnCode.Default
218225
return cache

0 commit comments

Comments
 (0)