Skip to content

Commit 8fb2284

Browse files
Merge pull request #292 from avik-pal/ap/trace
Add the ability to store the trace for nonlinear solve algorithms
2 parents 46912f2 + 1b9c188 commit 8fb2284

17 files changed

+499
-42
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1212
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1313
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1414
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
15+
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1516
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1617
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1718
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1819
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
20+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1921
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2022
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2123
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
@@ -43,22 +45,24 @@ ADTypes = "0.2"
4345
ArrayInterface = "6.0.24, 7"
4446
BandedMatrices = "1"
4547
ConcreteStructs = "0.2"
46-
DiffEqBase = "6.136"
48+
DiffEqBase = "6.141"
4749
EnumX = "1"
4850
Enzyme = "0.11"
4951
FastBroadcast = "0.1.9, 0.2"
5052
FastLevenbergMarquardt = "0.1"
5153
FiniteDiff = "2"
5254
ForwardDiff = "0.10.3"
55+
LazyArrays = "1.8"
5356
LeastSquaresOptim = "0.8"
5457
LineSearches = "7"
5558
LinearAlgebra = "<0.0.1, 1"
5659
LinearSolve = "2.12"
5760
NonlinearProblemLibrary = "0.1"
5861
PrecompileTools = "1"
62+
Printf = "<0.0.1, 1"
5963
RecursiveArrayTools = "2"
6064
Reexport = "0.2, 1"
61-
SciMLBase = "2.8.2"
65+
SciMLBase = "2.9"
6266
SciMLOperators = "0.3"
6367
SimpleNonlinearSolve = "0.1.23"
6468
SparseArrays = "<0.0.1, 1"

docs/pages.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pages = ["index.md",
1212
"basics/solve.md",
1313
"basics/NonlinearSolution.md",
1414
"basics/TerminationCondition.md",
15+
"basics/Logging.md",
1516
"basics/FAQ.md"],
1617
"Solver Summaries and Recommendations" => Any["solvers/NonlinearSystemSolvers.md",
1718
"solvers/BracketingSolvers.md",

docs/src/basics/Logging.md

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Logging the Solve Process
2+
3+
All NonlinearSolve.jl native solvers allow storing and displaying the trace of the nonlinear
4+
solve process. This is controlled by 3 keyword arguments to `solve`:
5+
6+
1. `show_trace`: Must be `Val(true)` or `Val(false)`. This controls whether the trace is
7+
displayed to the console. (Defaults to `Val(false)`)
8+
2. `trace_level`: Needs to be one of Trace Objects: [`TraceMinimal`](@ref),
9+
[`TraceWithJacobianConditionNumber`](@ref), or [`TraceAll`](@ref). This controls the
10+
level of detail of the trace. (Defaults to `TraceMinimal()`)
11+
3. `store_trace`: Must be `Val(true)` or `Val(false)`. This controls whether the trace is
12+
stored in the solution object. (Defaults to `Val(false)`)
13+
14+
## Example Usage
15+
16+
```@example tracing
17+
using ModelingToolkit, NonlinearSolve
18+
19+
@variables x y z
20+
@parameters σ ρ β
21+
22+
# Define a nonlinear system
23+
eqs = [0 ~ σ * (y - x),
24+
0 ~ x * (ρ - z) - y,
25+
0 ~ x * y - β * z]
26+
@named ns = NonlinearSystem(eqs, [x, y, z], [σ, ρ, β])
27+
28+
u0 = [x => 1.0, y => 0.0, z => 0.0]
29+
30+
ps = [σ => 10.0 ρ => 26.0 β => 8 / 3]
31+
32+
prob = NonlinearProblem(ns, u0, ps)
33+
34+
solve(prob)
35+
```
36+
37+
This produced the output, but it is hard to diagnose what is going on. We can turn on
38+
the trace to see what is happening:
39+
40+
```@example tracing
41+
solve(prob; show_trace = Val(true), trace_level = TraceAll(10))
42+
```
43+
44+
You can also store the trace in the solution object:
45+
46+
```@example tracing
47+
sol = solve(prob; trace_level = TraceAll(), store_trace = Val(true));
48+
49+
sol.trace
50+
```
51+
52+
!!! note
53+
54+
For `iteration == 0` only the `norm(fu, Inf)` is guaranteed to be meaningful. The other
55+
values being meaningful are solver dependent.
56+
57+
## API
58+
59+
```@docs
60+
TraceMinimal
61+
TraceWithJacobianConditionNumber
62+
TraceAll
63+
```

src/NonlinearSolve.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import Reexport: @reexport
88
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload
99

1010
@recompile_invalidations begin
11-
using DiffEqBase, LinearAlgebra, LinearSolve, SparseArrays, SparseDiffTools
11+
using DiffEqBase,
12+
LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
13+
SparseDiffTools
1214
using FastBroadcast: @..
1315
import ArrayInterface: restructure
1416

@@ -50,6 +52,32 @@ abstract type AbstractNonlinearSolveCache{iip} end
5052

5153
isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
5254

55+
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
56+
str = "$(nameof(typeof(alg)))("
57+
modifiers = String[]
58+
if _getproperty(alg, Val(:ad)) !== nothing
59+
push!(modifiers, "ad = $(nameof(typeof(alg.ad)))()")
60+
end
61+
if _getproperty(alg, Val(:linsolve)) !== nothing
62+
push!(modifiers, "linsolve = $(nameof(typeof(alg.linsolve)))()")
63+
end
64+
if _getproperty(alg, Val(:linesearch)) !== nothing
65+
ls = alg.linesearch
66+
if ls isa LineSearch
67+
ls.method !== nothing &&
68+
push!(modifiers, "linesearch = $(nameof(typeof(ls.method)))()")
69+
else
70+
push!(modifiers, "linesearch = $(nameof(typeof(alg.linesearch)))()")
71+
end
72+
end
73+
if _getproperty(alg, Val(:radius_update_scheme)) !== nothing
74+
push!(modifiers, "radius_update_scheme = $(alg.radius_update_scheme)")
75+
end
76+
str = str * join(modifiers, ", ")
77+
print(io, "$(str))")
78+
return nothing
79+
end
80+
5381
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
5482
alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
5583
cache = init(prob, alg, args...; kwargs...)
@@ -79,11 +107,18 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
79107
end
80108
end
81109

110+
trace = _getproperty(cache, Val{:trace}())
111+
if trace !== nothing
112+
update_trace!(trace, cache.stats.nsteps, get_u(cache), get_fu(cache), nothing,
113+
nothing, nothing; last = Val(true))
114+
end
115+
82116
return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
83-
cache.retcode, cache.stats)
117+
cache.retcode, cache.stats, trace)
84118
end
85119

86120
include("utils.jl")
121+
include("trace.jl")
87122
include("extension_algs.jl")
88123
include("linesearch.jl")
89124
include("raphson.jl")
@@ -162,4 +197,7 @@ export SteadyStateDiffEqTerminationMode, SimpleNonlinearSolveTerminationMode,
162197
AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode,
163198
RelSafeBestTerminationMode, AbsSafeBestTerminationMode
164199

200+
# Tracing Functionality
201+
export TraceAll, TraceMinimal, TraceWithJacobianConditionNumber
202+
165203
end # module

src/broyden.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Sadly `Broyden` is taken up by SimpleNonlinearSolve.jl
22
"""
3-
GeneralBroyden(; max_resets = 3, linesearch = LineSearch(), reset_tolerance = nothing)
3+
GeneralBroyden(; max_resets = 3, linesearch = nothing, reset_tolerance = nothing)
44
55
An implementation of `Broyden` with reseting and line search.
66
@@ -21,7 +21,7 @@ An implementation of `Broyden` with reseting and line search.
2121
linesearch
2222
end
2323

24-
function GeneralBroyden(; max_resets = 3, linesearch = LineSearch(),
24+
function GeneralBroyden(; max_resets = 3, linesearch = nothing,
2525
reset_tolerance = nothing)
2626
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
2727
return GeneralBroyden(max_resets, reset_tolerance, linesearch)
@@ -54,6 +54,7 @@ end
5454
stats::NLStats
5555
ls_cache
5656
tc_cache
57+
trace
5758
end
5859

5960
get_fu(cache::GeneralBroydenCache) = cache.fu
@@ -66,19 +67,22 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
6667
@unpack f, u0, p = prob
6768
u = alias_u0 ? u0 : deepcopy(u0)
6869
fu = evaluate_f(prob, u)
70+
du = _mutable_zero(u)
6971
J⁻¹ = __init_identity_jacobian(u, fu)
7072
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) :
7173
alg.reset_tolerance
7274
reset_check = x -> abs(x) reset_tolerance
7375

7476
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
7577
termination_condition)
78+
trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true),
79+
kwargs...)
7680

77-
return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu),
81+
return GeneralBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu),
7882
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
7983
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
8084
reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
81-
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache)
85+
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
8286
end
8387

8488
function perform_step!(cache::GeneralBroydenCache{true})
@@ -90,6 +94,9 @@ function perform_step!(cache::GeneralBroydenCache{true})
9094
_axpy!(-α, du, u)
9195
f(fu2, u, p)
9296

97+
update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
98+
get_fu(cache), J⁻¹, du, α)
99+
93100
check_and_update!(cache, fu2, u, u_prev)
94101
cache.stats.nf += 1
95102

@@ -131,6 +138,9 @@ function perform_step!(cache::GeneralBroydenCache{false})
131138
cache.u = cache.u .- α * cache.du
132139
cache.fu2 = f(cache.u, p)
133140

141+
update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
142+
get_fu(cache), cache.J⁻¹, cache.du, α)
143+
134144
check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
135145
cache.stats.nf += 1
136146

@@ -173,6 +183,7 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca
173183
cache.fu = cache.f(cache.u, p)
174184
end
175185

186+
reset!(cache.trace)
176187
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
177188
termination_condition)
178189

src/default.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ end
8282
fu = get_fu($(cache_syms[i]))
8383
return SciMLBase.build_solution($(sol_syms[i]).prob, cache.alg, u,
8484
fu; retcode = ReturnCode.Success, stats,
85-
original = $(sol_syms[i]))
85+
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace)
8686
end
8787
cache.current = $(i + 1)
8888
end
@@ -103,7 +103,7 @@ end
103103
u = cache.caches[idx].u
104104

105105
return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u,
106-
fus[idx]; retcode, stats)
106+
fus[idx]; retcode, stats, cache.caches[idx].trace)
107107
end)
108108

109109
return Expr(:block, calls...)
@@ -125,7 +125,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
125125
if SciMLBase.successful_retcode($(cur_sol))
126126
return SciMLBase.build_solution(prob, alg, $(cur_sol).u,
127127
$(cur_sol).resid; $(cur_sol).retcode, $(cur_sol).stats,
128-
original = $(cur_sol))
128+
original = $(cur_sol), trace = $(cur_sol).trace)
129129
end
130130
end)
131131
end
@@ -147,7 +147,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
147147
if idx == $i
148148
return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u,
149149
$(sol_syms[i]).resid; $(sol_syms[i]).retcode,
150-
$(sol_syms[i]).stats)
150+
$(sol_syms[i]).stats, $(sol_syms[i]).trace)
151151
end
152152
end)
153153
end

src/dfsane.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ end
9090
prob
9191
stats::NLStats
9292
tc_cache
93+
trace
9394
end
9495

9596
get_fu(cache::DFSaneCache) = cache.fu
@@ -113,11 +114,12 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.
113114

114115
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, uprev,
115116
termination_condition)
117+
trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...)
116118

117119
return DFSaneCache{iip}(alg, u, uprev, fu, fuprev, du, history, f_norm, f_norm_0, alg.M,
118120
T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ), T(alg.τ_min),
119121
T(alg.τ_max), alg.n_exp, prob.p, false, maxiters, internalnorm, ReturnCode.Default,
120-
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache)
122+
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
121123
end
122124

123125
function perform_step!(cache::DFSaneCache{true})
@@ -164,6 +166,9 @@ function perform_step!(cache::DFSaneCache{true})
164166
f_norm = cache.internalnorm(cache.fu)^n_exp
165167
end
166168

169+
update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), nothing,
170+
cache.du, α₊)
171+
167172
check_and_update!(cache, cache.fu, cache.u, cache.uprev)
168173

169174
# Update spectral parameter
@@ -236,6 +241,9 @@ function perform_step!(cache::DFSaneCache{false})
236241
f_norm = cache.internalnorm(cache.fu)^n_exp
237242
end
238243

244+
update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), nothing,
245+
cache.du, α₊)
246+
239247
check_and_update!(cache, cache.fu, cache.u, cache.uprev)
240248

241249
# Update spectral parameter
@@ -288,6 +296,7 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.u; p = cache.p,
288296
T = eltype(cache.u)
289297
cache.σ_n = T(cache.alg.σ_1)
290298

299+
reset!(cache.trace)
291300
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
292301
termination_condition)
293302

0 commit comments

Comments
 (0)