Skip to content

Commit 8134a83

Browse files
committed
Add the ability to store the trace for nonlinear solve algorithms
1 parent 46912f2 commit 8134a83

File tree

4 files changed

+132
-8
lines changed

4 files changed

+132
-8
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1616
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1717
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1818
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
19+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1920
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2021
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2122
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
@@ -56,6 +57,7 @@ LinearAlgebra = "<0.0.1, 1"
5657
LinearSolve = "2.12"
5758
NonlinearProblemLibrary = "0.1"
5859
PrecompileTools = "1"
60+
Printf = "<0.0.1, 1"
5961
RecursiveArrayTools = "2"
6062
Reexport = "0.2, 1"
6163
SciMLBase = "2.8.2"

src/NonlinearSolve.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import Reexport: @reexport
88
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload
99

1010
@recompile_invalidations begin
11-
using DiffEqBase, LinearAlgebra, LinearSolve, SparseArrays, SparseDiffTools
11+
using DiffEqBase, LinearAlgebra, LinearSolve, Printf, SparseArrays, SparseDiffTools
1212
using FastBroadcast: @..
1313
import ArrayInterface: restructure
1414

@@ -79,11 +79,14 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
7979
end
8080
end
8181

82+
trace = _getproperty(cache, Val{:trace}())
83+
8284
return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
83-
cache.retcode, cache.stats)
85+
cache.retcode, cache.stats, trace)
8486
end
8587

8688
include("utils.jl")
89+
include("trace.jl")
8790
include("extension_algs.jl")
8891
include("linesearch.jl")
8992
include("raphson.jl")

src/raphson.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ for large-scale and numerically-difficult nonlinear systems.
3030
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
3131
used here directly, and they will be converted to the correct `LineSearch`.
3232
"""
33-
@concrete struct NewtonRaphson{CJ, AD} <:
34-
AbstractNewtonAlgorithm{CJ, AD}
33+
@concrete struct NewtonRaphson{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
3534
ad::AD
3635
linsolve
3736
precs
@@ -72,12 +71,13 @@ end
7271
stats::NLStats
7372
ls_cache
7473
tc_cache
74+
trace
7575
end
7676

7777
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphson, args...;
7878
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
79-
termination_condition = nothing, internalnorm = DEFAULT_NORM,
80-
linsolve_kwargs = (;), kwargs...) where {uType, iip}
79+
termination_condition = nothing, internalnorm = DEFAULT_NORM, linsolve_kwargs = (;),
80+
kwargs...) where {uType, iip}
8181
alg = get_concrete_algorithm(alg_, prob)
8282
@unpack f, u0, p = prob
8383
u = alias_u0 ? u0 : deepcopy(u0)
@@ -88,10 +88,12 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
8888
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu1, u,
8989
termination_condition)
9090

91+
ls_cache = init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip))
92+
trace = init_nonlinearsolve_trace(u, fu1, J, du; kwargs...)
93+
9194
return NewtonRaphsonCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve, J,
9295
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
93-
NLStats(1, 0, 0, 0, 0),
94-
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)), tc_cache)
96+
NLStats(1, 0, 0, 0, 0), ls_cache, tc_cache, trace)
9597
end
9698

9799
function perform_step!(cache::NewtonRaphsonCache{true})
@@ -108,6 +110,8 @@ function perform_step!(cache::NewtonRaphsonCache{true})
108110
_axpy!(-α, du, u)
109111
f(cache.fu1, u, p)
110112

113+
update_trace!(cache.trace, cache.stats.nsteps + 1, u, cache.fu1, J, du)
114+
111115
check_and_update!(cache, cache.fu1, cache.u, cache.u_prev)
112116

113117
@. u_prev = u
@@ -136,6 +140,9 @@ function perform_step!(cache::NewtonRaphsonCache{false})
136140
cache.u = @. u - α * cache.du # `u` might not support mutation
137141
cache.fu1 = f(cache.u, p)
138142

143+
update_trace!(cache.trace, cache.stats.nsteps + 1, cache.u, cache.fu1, cache.J,
144+
cache.du)
145+
139146
check_and_update!(cache, cache.fu1, cache.u, cache.u_prev)
140147

141148
cache.u_prev = cache.u

src/trace.jl

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
@concrete struct NonlinearSolveTraceEntry
2+
iteration::Int
3+
fnorm
4+
stepnorm
5+
condJ
6+
J
7+
u
8+
fu
9+
δu
10+
end
11+
12+
function __show_top_level(io::IO, entry::NonlinearSolveTraceEntry)
13+
if entry.condJ === nothing
14+
@printf io "%-6s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm"
15+
@printf io "%-6s %-20s %-20s\n" "----" "-------------" "-----------"
16+
else
17+
@printf io "%-6s %-20s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm" "cond(J)"
18+
@printf io "%-6s %-20s %-20s %-20s\n" "----" "-------------" "-----------" "-------"
19+
end
20+
end
21+
22+
function Base.show(io::IO, entry::NonlinearSolveTraceEntry)
23+
entry.iteration == 0 && __show_top_level(io, entry)
24+
if entry.condJ === nothing
25+
@printf io "%-6d %-20.8e %-20.8e\n" entry.iteration entry.fnorm entry.stepnorm
26+
else
27+
@printf io "%-6d %-20.8e %-20.8e %-20.8e\n" entry.iteration entry.fnorm entry.stepnorm entry.condJ
28+
end
29+
return nothing
30+
end
31+
32+
function NonlinearSolveTraceEntry(iteration, fu, δu)
33+
return NonlinearSolveTraceEntry(iteration, norm(fu, Inf), norm(δu, 2), nothing,
34+
nothing, nothing, nothing, nothing)
35+
end
36+
37+
function NonlinearSolveTraceEntry(iteration, fu, δu, J)
38+
return NonlinearSolveTraceEntry(iteration, norm(fu, Inf), norm(δu, 2), __cond(J), nothing,
39+
nothing, nothing, nothing)
40+
end
41+
42+
function NonlinearSolveTraceEntry(iteration, fu, δu, J, u)
43+
return NonlinearSolveTraceEntry(iteration, norm(fu, Inf), norm(δu, 2), __cond(J),
44+
copy(J), copy(u), copy(fu), copy(δu))
45+
end
46+
47+
__cond(J::AbstractMatrix) = cond(J)
48+
__cond(J) = NaN # Covers cases where `J` is a Operator, nothing, etc.
49+
50+
@concrete struct NonlinearSolveTrace{show_trace, trace_level, store_trace}
51+
history
52+
end
53+
54+
function Base.show(io::IO, trace::NonlinearSolveTrace)
55+
for entry in trace.history
56+
show(io, entry)
57+
end
58+
return nothing
59+
end
60+
61+
function init_nonlinearsolve_trace(u, fu, J, δu; show_trace::Val = Val(false),
62+
trace_level::Val = Val(1), store_trace::Val = Val(false), kwargs...)
63+
return init_nonlinearsolve_trace(show_trace, trace_level, store_trace, u, fu, J, δu)
64+
end
65+
66+
function init_nonlinearsolve_trace(::Val{show_trace}, ::Val{trace_level},
67+
::Val{store_trace}, u, fu, J, δu) where {show_trace, trace_level, store_trace}
68+
history = __init_trace_history(Val{show_trace}(), Val{trace_level}(),
69+
Val{store_trace}(), u, fu, J, δu)
70+
return NonlinearSolveTrace{show_trace, trace_level, store_trace}(history)
71+
end
72+
73+
function __init_trace_history(::Val{show_trace}, ::Val{trace_level}, ::Val{store_trace}, u,
74+
fu, J, δu) where {show_trace, trace_level, store_trace}
75+
!store_trace && !show_trace && return nothing
76+
entry = __trace_entry(Val{trace_level}(), 0, u, fu, J, δu)
77+
show_trace && show(entry)
78+
store_trace && return [entry]
79+
return nothing
80+
end
81+
82+
__trace_entry(::Val{1}, iter, u, fu, J, δu) = NonlinearSolveTraceEntry(iter, fu, δu)
83+
__trace_entry(::Val{2}, iter, u, fu, J, δu) = NonlinearSolveTraceEntry(iter, fu, δu, J)
84+
__trace_entry(::Val{3}, iter, u, fu, J, δu) = NonlinearSolveTraceEntry(iter, fu, δu, J, u)
85+
function __trace_entry(::Val{T}, iter, u, fu, J, δu) where {T}
86+
throw(ArgumentError("::Val{trace_level} == ::Val{$(T)} is not supported. \
87+
Possible values are `Val{1}()`/`Val{2}()`/`Val{3}()`."))
88+
end
89+
90+
function update_trace!(trace::NonlinearSolveTrace{ShT, TrL, StT}, iter, u, fu, J,
91+
δu) where {ShT, TrL, StT}
92+
!StT && !ShT && return nothing
93+
entry = __trace_entry(Val{TrL}(), iter, u, fu, J, δu)
94+
StT && push!(trace.history, entry)
95+
ShT && show(entry)
96+
return trace
97+
end
98+
99+
# Needed for Algorithms which directly use `inv(J)` instead of `J`
100+
function update_trace_with_invJ!(trace::NonlinearSolveTrace{ShT, TrL, StT}, iter, u, fu, J,
101+
δu) where {ShT, TrL, StT}
102+
!StT && !ShT && return nothing
103+
if TrL == 1
104+
entry = __trace_entry(Val{1}(), iter, u, fu, J, δu)
105+
else
106+
entry = __trace_entry(Val{TrL}(), iter, u, fu, inv(J), δu)
107+
end
108+
StT && push!(trace.history, entry)
109+
ShT && show(entry)
110+
return trace
111+
end
112+

0 commit comments

Comments
 (0)