Skip to content

Commit 354e080

Browse files
committed
Support tracing for all and add an example
1 parent be7b44e commit 354e080

16 files changed

+397
-84
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1212
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1313
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1414
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
15+
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1516
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1617
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1718
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
@@ -44,13 +45,14 @@ ADTypes = "0.2"
4445
ArrayInterface = "6.0.24, 7"
4546
BandedMatrices = "1"
4647
ConcreteStructs = "0.2"
47-
DiffEqBase = "6.136"
48+
DiffEqBase = "6.141"
4849
EnumX = "1"
4950
Enzyme = "0.11"
5051
FastBroadcast = "0.1.9, 0.2"
5152
FastLevenbergMarquardt = "0.1"
5253
FiniteDiff = "2"
5354
ForwardDiff = "0.10.3"
55+
LazyArrays = "1.8"
5456
LeastSquaresOptim = "0.8"
5557
LineSearches = "7"
5658
LinearAlgebra = "<0.0.1, 1"
@@ -60,7 +62,7 @@ PrecompileTools = "1"
6062
Printf = "<0.0.1, 1"
6163
RecursiveArrayTools = "2"
6264
Reexport = "0.2, 1"
63-
SciMLBase = "2.8.2"
65+
SciMLBase = "2.9"
6466
SciMLOperators = "0.3"
6567
SimpleNonlinearSolve = "0.1.23"
6668
SparseArrays = "<0.0.1, 1"

docs/pages.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pages = ["index.md",
1212
"basics/solve.md",
1313
"basics/NonlinearSolution.md",
1414
"basics/TerminationCondition.md",
15+
"basics/Logging.md",
1516
"basics/FAQ.md"],
1617
"Solver Summaries and Recommendations" => Any["solvers/NonlinearSystemSolvers.md",
1718
"solvers/BracketingSolvers.md",

docs/src/basics/Logging.md

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# [Logging the Solve Process](@ logging_api)
2+
3+
All NonlinearSolve.jl native solvers allow storing and displaying the trace of the nonlinear
4+
solve process. This is controlled by 3 keyword arguments to `solve`:
5+
6+
1. `show_trace`: Must be `Val(true)` or `Val(false)`. This controls whether the trace is
7+
displayed to the console. (Defaults to `Val(false)`)
8+
2. `trace_level`: Needs to be one of Trace Objects: [`TraceMinimal`](@ref),
9+
[`TraceWithJacobianConditionNumber`](@ref), or [`TraceAll`](@ref). This controls the
10+
level of detail of the trace. (Defaults to `TraceMinimal()`)
11+
3. `store_trace`: Must be `Val(true)` or `Val(false)`. This controls whether the trace is
12+
stored in the solution object. (Defaults to `Val(false)`)
13+
14+
## Example Usage
15+
16+
```@example tracing
17+
using ModelingToolkit, NonlinearSolve
18+
19+
@variables x y z
20+
@parameters σ ρ β
21+
22+
# Define a nonlinear system
23+
eqs = [0 ~ σ * (y - x),
24+
0 ~ x * (ρ - z) - y,
25+
0 ~ x * y - β * z]
26+
@named ns = NonlinearSystem(eqs, [x, y, z], [σ, ρ, β])
27+
28+
u0 = [x => 1.0, y => 0.0, z => 0.0]
29+
30+
ps = [σ => 10.0 ρ => 26.0 β => 8 / 3]
31+
32+
prob = NonlinearProblem(ns, u0, ps)
33+
34+
solve(prob)
35+
```
36+
37+
This produced the output, but it is hard to diagnose what is going on. We can turn on
38+
the trace to see what is happening:
39+
40+
```@example tracing
41+
solve(prob; show_trace = Val(true), trace_level = TraceAll(10))
42+
```
43+
44+
You can also store the trace in the solution object:
45+
46+
```@example tracing
47+
sol = solve(prob; trace_level = TraceAll(), store_trace = Val(true));
48+
49+
sol.trace
50+
```
51+
52+
## API
53+
54+
```@docs
55+
TraceMinimal
56+
TraceWithJacobianConditionNumber
57+
TraceAll
58+
```

src/NonlinearSolve.jl

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import Reexport: @reexport
88
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload
99

1010
@recompile_invalidations begin
11-
using DiffEqBase, LinearAlgebra, LinearSolve, Printf, SparseArrays, SparseDiffTools
11+
using DiffEqBase,
12+
LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
13+
SparseDiffTools
1214
using FastBroadcast: @..
1315
import ArrayInterface: restructure
1416

@@ -50,6 +52,32 @@ abstract type AbstractNonlinearSolveCache{iip} end
5052

5153
isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
5254

55+
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
56+
str = "$(nameof(typeof(alg)))("
57+
modifiers = String[]
58+
if _getproperty(alg, Val(:ad)) !== nothing
59+
push!(modifiers, "ad = $(nameof(typeof(alg.ad)))()")
60+
end
61+
if _getproperty(alg, Val(:linsolve)) !== nothing
62+
push!(modifiers, "linsolve = $(nameof(typeof(alg.linsolve)))()")
63+
end
64+
if _getproperty(alg, Val(:linesearch)) !== nothing
65+
ls = alg.linesearch
66+
if ls isa LineSearch
67+
ls.method !== nothing &&
68+
push!(modifiers, "linesearch = $(nameof(typeof(ls.method)))()")
69+
else
70+
push!(modifiers, "linesearch = $(nameof(typeof(alg.linesearch)))()")
71+
end
72+
end
73+
if _getproperty(alg, Val(:radius_update_scheme)) !== nothing
74+
push!(modifiers, "radius_update_scheme = $(alg.radius_update_scheme)")
75+
end
76+
str = str * join(modifiers, ", ")
77+
print(io, "$(str))")
78+
return nothing
79+
end
80+
5381
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
5482
alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
5583
cache = init(prob, alg, args...; kwargs...)
@@ -80,6 +108,10 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
80108
end
81109

82110
trace = _getproperty(cache, Val{:trace}())
111+
if trace !== nothing
112+
update_trace!(trace, cache.stats.nsteps, get_u(cache), get_fu(cache), nothing,
113+
nothing, nothing; last = Val(true))
114+
end
83115

84116
return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
85117
cache.retcode, cache.stats, trace)
@@ -165,4 +197,7 @@ export SteadyStateDiffEqTerminationMode, SimpleNonlinearSolveTerminationMode,
165197
AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode,
166198
RelSafeBestTerminationMode, AbsSafeBestTerminationMode
167199

200+
# Tracing Functionality
201+
export TraceAll, TraceMinimal, TraceWithJacobianConditionNumber
202+
168203
end # module

src/broyden.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Sadly `Broyden` is taken up by SimpleNonlinearSolve.jl
22
"""
3-
GeneralBroyden(; max_resets = 3, linesearch = LineSearch(), reset_tolerance = nothing)
3+
GeneralBroyden(; max_resets = 3, linesearch = nothing, reset_tolerance = nothing)
44
55
An implementation of `Broyden` with reseting and line search.
66
@@ -21,7 +21,7 @@ An implementation of `Broyden` with reseting and line search.
2121
linesearch
2222
end
2323

24-
function GeneralBroyden(; max_resets = 3, linesearch = LineSearch(),
24+
function GeneralBroyden(; max_resets = 3, linesearch = nothing,
2525
reset_tolerance = nothing)
2626
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
2727
return GeneralBroyden(max_resets, reset_tolerance, linesearch)
@@ -54,6 +54,7 @@ end
5454
stats::NLStats
5555
ls_cache
5656
tc_cache
57+
trace
5758
end
5859

5960
get_fu(cache::GeneralBroydenCache) = cache.fu
@@ -66,19 +67,22 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
6667
@unpack f, u0, p = prob
6768
u = alias_u0 ? u0 : deepcopy(u0)
6869
fu = evaluate_f(prob, u)
70+
du = _mutable_zero(u)
6971
J⁻¹ = __init_identity_jacobian(u, fu)
7072
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) :
7173
alg.reset_tolerance
7274
reset_check = x -> abs(x) reset_tolerance
7375

7476
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
7577
termination_condition)
78+
trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true),
79+
kwargs...)
7680

77-
return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu),
81+
return GeneralBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu),
7882
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
7983
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
8084
reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
81-
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache)
85+
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
8286
end
8387

8488
function perform_step!(cache::GeneralBroydenCache{true})
@@ -90,6 +94,9 @@ function perform_step!(cache::GeneralBroydenCache{true})
9094
_axpy!(-α, du, u)
9195
f(fu2, u, p)
9296

97+
update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
98+
get_fu(cache), J⁻¹, du, α)
99+
93100
check_and_update!(cache, fu2, u, u_prev)
94101
cache.stats.nf += 1
95102

@@ -131,6 +138,9 @@ function perform_step!(cache::GeneralBroydenCache{false})
131138
cache.u = cache.u .- α * cache.du
132139
cache.fu2 = f(cache.u, p)
133140

141+
update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
142+
get_fu(cache), cache.J⁻¹, cache.du, α)
143+
134144
check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
135145
cache.stats.nf += 1
136146

@@ -173,6 +183,7 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca
173183
cache.fu = cache.f(cache.u, p)
174184
end
175185

186+
reset!(cache.trace)
176187
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
177188
termination_condition)
178189

src/default.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ end
8282
fu = get_fu($(cache_syms[i]))
8383
return SciMLBase.build_solution($(sol_syms[i]).prob, cache.alg, u,
8484
fu; retcode = ReturnCode.Success, stats,
85-
original = $(sol_syms[i]))
85+
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace)
8686
end
8787
cache.current = $(i + 1)
8888
end
@@ -103,7 +103,7 @@ end
103103
u = cache.caches[idx].u
104104

105105
return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u,
106-
fus[idx]; retcode, stats)
106+
fus[idx]; retcode, stats, cache.caches[idx].trace)
107107
end)
108108

109109
return Expr(:block, calls...)
@@ -125,7 +125,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
125125
if SciMLBase.successful_retcode($(cur_sol))
126126
return SciMLBase.build_solution(prob, alg, $(cur_sol).u,
127127
$(cur_sol).resid; $(cur_sol).retcode, $(cur_sol).stats,
128-
original = $(cur_sol))
128+
original = $(cur_sol), trace = $(cur_sol).trace)
129129
end
130130
end)
131131
end
@@ -147,7 +147,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
147147
if idx == $i
148148
return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u,
149149
$(sol_syms[i]).resid; $(sol_syms[i]).retcode,
150-
$(sol_syms[i]).stats)
150+
$(sol_syms[i]).stats, $(sol_syms[i]).trace)
151151
end
152152
end)
153153
end

src/dfsane.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ end
9090
prob
9191
stats::NLStats
9292
tc_cache
93+
trace
9394
end
9495

9596
get_fu(cache::DFSaneCache) = cache.fu
@@ -113,11 +114,12 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.
113114

114115
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, uprev,
115116
termination_condition)
117+
trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...)
116118

117119
return DFSaneCache{iip}(alg, u, uprev, fu, fuprev, du, history, f_norm, f_norm_0, alg.M,
118120
T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ), T(alg.τ_min),
119121
T(alg.τ_max), alg.n_exp, prob.p, false, maxiters, internalnorm, ReturnCode.Default,
120-
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache)
122+
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
121123
end
122124

123125
function perform_step!(cache::DFSaneCache{true})
@@ -164,6 +166,9 @@ function perform_step!(cache::DFSaneCache{true})
164166
f_norm = cache.internalnorm(cache.fu)^n_exp
165167
end
166168

169+
update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), nothing,
170+
cache.du, α₊)
171+
167172
check_and_update!(cache, cache.fu, cache.u, cache.uprev)
168173

169174
# Update spectral parameter
@@ -236,6 +241,9 @@ function perform_step!(cache::DFSaneCache{false})
236241
f_norm = cache.internalnorm(cache.fu)^n_exp
237242
end
238243

244+
update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), nothing,
245+
cache.du, α₊)
246+
239247
check_and_update!(cache, cache.fu, cache.u, cache.uprev)
240248

241249
# Update spectral parameter
@@ -288,6 +296,7 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.u; p = cache.p,
288296
T = eltype(cache.u)
289297
cache.σ_n = T(cache.alg.σ_1)
290298

299+
reset!(cache.trace)
291300
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
292301
termination_condition)
293302

src/gaussnewton.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
2+
GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = nothing,
33
precs = DEFAULT_PRECS, adkwargs...)
44
55
An advanced GaussNewton implementation with support for efficient handling of sparse
@@ -47,7 +47,7 @@ function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
4747
end
4848

4949
function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
50-
linesearch = LineSearch(), precs = DEFAULT_PRECS, vjp_autodiff = nothing,
50+
linesearch = nothing, precs = DEFAULT_PRECS, vjp_autodiff = nothing,
5151
adkwargs...)
5252
ad = default_adargs_to_adtype(; adkwargs...)
5353
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
@@ -82,6 +82,7 @@ end
8282
tc_cache_1
8383
tc_cache_2
8484
ls_cache
85+
trace
8586
end
8687

8788
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
@@ -108,11 +109,12 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
108109
abstol, reltol, tc_cache_1 = init_termination_cache(abstol, reltol, fu1, u,
109110
termination_condition)
110111
_, _, tc_cache_2 = init_termination_cache(abstol, reltol, fu1, u, termination_condition)
112+
trace = init_nonlinearsolve_trace(alg, u, fu1, J, du; kwargs...)
111113

112114
return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
113115
linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
114116
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2,
115-
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)))
117+
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)), trace)
116118
end
117119

118120
function perform_step!(cache::GaussNewtonCache{true})
@@ -137,6 +139,9 @@ function perform_step!(cache::GaussNewtonCache{true})
137139
_axpy!(-α, du, u)
138140
f(cache.fu_new, u, p)
139141

142+
update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), J,
143+
cache.du, α)
144+
140145
check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
141146
if !cache.force_stop
142147
cache.fu1 .= cache.fu_new .- cache.fu1
@@ -179,6 +184,9 @@ function perform_step!(cache::GaussNewtonCache{false})
179184
cache.u = @. u - α * cache.du # `u` might not support mutation
180185
cache.fu_new = f(cache.u, p)
181186

187+
update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), cache.J,
188+
cache.du, α)
189+
182190
check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
183191
if !cache.force_stop
184192
cache.fu1 = cache.fu_new .- cache.fu1
@@ -207,6 +215,7 @@ function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache
207215
cache.fu1 = cache.f(cache.u, p)
208216
end
209217

218+
reset!(cache.trace)
210219
abstol, reltol, tc_cache_1 = init_termination_cache(abstol, reltol, cache.fu1, cache.u,
211220
termination_condition)
212221
_, _, tc_cache_2 = init_termination_cache(abstol, reltol, cache.fu1, cache.u,

0 commit comments

Comments
 (0)