Skip to content

Commit a30f80b

Browse files
committed
refactor: move tracing functionality to NonlinearSolveBase
1 parent 5ca1074 commit a30f80b

File tree

11 files changed

+289
-233
lines changed

11 files changed

+289
-233
lines changed

ext/NonlinearSolveLeastSquaresOptimExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ module NonlinearSolveLeastSquaresOptimExt
22

33
using ConcreteStructs: @concrete
44
using LeastSquaresOptim: LeastSquaresOptim
5-
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
6-
using NonlinearSolve: NonlinearSolve, LeastSquaresOptimJL, TraceMinimal
5+
using NonlinearSolveBase: NonlinearSolveBase, TraceMinimal, get_tolerance
6+
using NonlinearSolve: NonlinearSolve, LeastSquaresOptimJL
77
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode
88

99
const LSO = LeastSquaresOptim

ext/NonlinearSolveNLsolveExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
module NonlinearSolveNLsolveExt
22

33
using LineSearches: Static
4-
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
5-
using NonlinearSolve: NonlinearSolve, NLsolveJL, TraceMinimal
4+
using NonlinearSolveBase: NonlinearSolveBase, TraceMinimal, get_tolerance
5+
using NonlinearSolve: NonlinearSolve, NLsolveJL
66
using NLsolve: NLsolve, OnceDifferentiable, nlsolve
77
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
88

@@ -32,7 +32,7 @@ function SciMLBase.__solve(
3232
abstol = get_tolerance(abstol, eltype(u0))
3333
show_trace = ShT
3434
store_trace = StT
35-
extended_trace = !(trace_level isa TraceMinimal)
35+
extended_trace = !(trace_level.trace_mode isa Val{:minimal})
3636

3737
linesearch = alg.linesearch === missing ? Static() : alg.linesearch
3838

lib/NonlinearSolveBase/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1919
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
2020
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
21+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2122
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2223
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2324
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
@@ -64,6 +65,7 @@ LinearSolve = "2.36.1"
6465
Markdown = "1.10"
6566
MaybeInplace = "0.1.4"
6667
Preferences = "1.4"
68+
Printf = "1.10"
6769
RecursiveArrayTools = "3"
6870
SciMLBase = "2.50"
6971
SciMLJacobianOperators = "0.1.1"

lib/NonlinearSolveBase/ext/NonlinearSolveBaseSparseArraysExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,8 @@ Utils.maybe_symmetric(x::AbstractSparseMatrix) = x
1313

1414
Utils.make_sparse(x) = sparse(x)
1515

16+
Utils.condition_number(J::AbstractSparseMatrix) = Utils.condition_number(Matrix(J))
17+
18+
Utils.maybe_pinv!!_workspace(A::AbstractSparseMatrix) = Matrix(A)
19+
1620
end

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind
1515
using Markdown: @doc_str
1616
using MaybeInplace: @bb
1717
using Preferences: @load_preference
18+
using Printf: @printf
1819
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
1920
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
2021
AbstractNonlinearAlgorithm, AbstractNonlinearFunction,
@@ -39,6 +40,7 @@ include("autodiff.jl")
3940
include("jacobian.jl")
4041
include("linear_solve.jl")
4142
include("timer_outputs.jl")
43+
include("tracing.jl")
4244

4345
include("descent/common.jl")
4446
include("descent/newton.jl")
@@ -59,6 +61,8 @@ include("descent/geodesic_acceleration.jl")
5961
@compat(public, (construct_linear_solver, needs_square_A, needs_concrete_A))
6062
@compat(public, (construct_jacobian_cache,))
6163

64+
export TraceMinimal, TraceWithJacobianConditionNumber, TraceAll
65+
6266
export RelTerminationMode, AbsTerminationMode,
6367
NormTerminationMode, RelNormTerminationMode, AbsNormTerminationMode,
6468
RelNormSafeTerminationMode, AbsNormSafeTerminationMode,

lib/NonlinearSolveBase/src/abstract_types.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ concrete_jac(::Val{true}) = true
220220

221221
abstract type AbstractNonlinearSolveCache <: AbstractNonlinearSolveBaseAPI end
222222

223+
function get_u end
224+
function get_fu end
225+
223226
"""
224227
AbstractLinearSolverCache
225228

lib/NonlinearSolveBase/src/tracing.jl

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
@concrete struct NonlinearSolveTracing
2+
trace_mode <: Union{Val{:minimal}, Val{:condition_number}, Val{:all}}
3+
print_frequency::Int
4+
store_frequency::Int
5+
end
6+
7+
"""
8+
TraceMinimal(freq)
9+
TraceMinimal(; print_frequency = 1, store_frequency::Int = 1)
10+
11+
Trace Minimal Information
12+
13+
1. Iteration Number
14+
2. f(u) inf-norm
15+
3. Step 2-norm
16+
17+
See also [`TraceWithJacobianConditionNumber`](@ref) and [`TraceAll`](@ref).
18+
"""
19+
function TraceMinimal(; print_frequency = 1, store_frequency::Int = 1)
20+
return NonlinearSolveTracing(Val(:minimal), print_frequency, store_frequency)
21+
end
22+
23+
"""
24+
TraceWithJacobianConditionNumber(freq)
25+
TraceWithJacobianConditionNumber(; print_frequency = 1, store_frequency::Int = 1)
26+
27+
[`TraceMinimal`](@ref) + Print the Condition Number of the Jacobian.
28+
29+
See also [`TraceMinimal`](@ref) and [`TraceAll`](@ref).
30+
"""
31+
function TraceWithJacobianConditionNumber(; print_frequency = 1, store_frequency::Int = 1)
32+
return NonlinearSolveTracing(Val(:condition_number), print_frequency, store_frequency)
33+
end
34+
35+
"""
36+
TraceAll(freq)
37+
TraceAll(; print_frequency = 1, store_frequency::Int = 1)
38+
39+
[`TraceWithJacobianConditionNumber`](@ref) + Store the Jacobian, u, f(u), and δu.
40+
41+
!!! warning
42+
43+
This is very expensive and makes copyies of the Jacobian, u, f(u), and δu.
44+
45+
See also [`TraceMinimal`](@ref) and [`TraceWithJacobianConditionNumber`](@ref).
46+
"""
47+
function TraceAll(; print_frequency = 1, store_frequency::Int = 1)
48+
return NonlinearSolveTracing(Val(:all), print_frequency, store_frequency)
49+
end
50+
51+
for Tr in (:TraceMinimal, :TraceWithJacobianConditionNumber, :TraceAll)
52+
@eval $(Tr)(freq) = $(Tr)(; print_frequency = freq, store_frequency = freq)
53+
end
54+
55+
# NonlinearSolve Tracing Utilities
56+
@concrete struct NonlinearSolveTraceEntry
57+
iteration::Int
58+
fnorm
59+
stepnorm
60+
condJ
61+
storage
62+
norm_type::Symbol
63+
end
64+
65+
function Base.getproperty(entry::NonlinearSolveTraceEntry, sym::Symbol)
66+
hasfield(typeof(entry), sym) && return getfield(entry, sym)
67+
return getproperty(entry.storage, sym)
68+
end
69+
70+
function print_top_level(io::IO, entry::NonlinearSolveTraceEntry)
71+
if entry.condJ === nothing
72+
@printf io "%-8s\t%-20s\t%-20s\n" "----" "-------------" "-----------"
73+
if entry.norm_type === :L2
74+
@printf io "%-8s\t%-20s\t%-20s\n" "Iter" "f(u) 2-norm" "Step 2-norm"
75+
else
76+
@printf io "%-8s\t%-20s\t%-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm"
77+
end
78+
@printf io "%-8s\t%-20s\t%-20s\n" "----" "-------------" "-----------"
79+
else
80+
@printf io "%-8s\t%-20s\t%-20s\t%-20s\n" "----" "-------------" "-----------" "-------"
81+
if entry.norm_type === :L2
82+
@printf io "%-8s\t%-20s\t%-20s\t%-20s\n" "Iter" "f(u) 2-norm" "Step 2-norm" "cond(J)"
83+
else
84+
@printf io "%-8s\t%-20s\t%-20s\t%-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm" "cond(J)"
85+
end
86+
@printf io "%-8s\t%-20s\t%-20s\t%-20s\n" "----" "-------------" "-----------" "-------"
87+
end
88+
end
89+
90+
function Base.show(io::IO, ::MIME"text/plain", entry::NonlinearSolveTraceEntry)
91+
entry.iteration == 0 && print_top_level(io, entry)
92+
if entry.iteration < 0 # Special case for final entry
93+
@printf io "%-8s\t%-20.8e\n" "Final" entry.fnorm
94+
@printf io "%-28s\n" "----------------------"
95+
elseif entry.condJ === nothing
96+
@printf io "%-8d\t%-20.8e\t%-20.8e\n" entry.iteration entry.fnorm entry.stepnorm
97+
else
98+
@printf io "%-8d\t%-20.8e\t%-20.8e\t%-20.8e\n" entry.iteration entry.fnorm entry.stepnorm entry.condJ
99+
end
100+
end
101+
102+
function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu, δu, J, u)
103+
norm_type = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
104+
fnorm = prob isa NonlinearLeastSquaresProblem ? L2_NORM(fu) : Linf_NORM(fu)
105+
condJ = J !== missing ? Utils.condition_number(J) : nothing
106+
storage = u === missing ? nothing :
107+
(; u = copy(u), fu = copy(fu), δu = copy(δu), J = copy(J))
108+
return NonlinearSolveTraceEntry(
109+
iteration, fnorm, L2_NORM(δu), condJ, storage, norm_type
110+
)
111+
end
112+
113+
@concrete struct NonlinearSolveTrace
114+
show_trace <: Union{Val{false}, Val{true}}
115+
store_trace <: Union{Val{false}, Val{true}}
116+
history
117+
trace_level <: NonlinearSolveTracing
118+
prob
119+
end
120+
121+
reset!(trace::NonlinearSolveTrace) = reset!(trace.history)
122+
reset!(::Nothing) = nothing
123+
reset!(history::Vector) = empty!(history)
124+
125+
function Base.show(io::IO, ::MIME"text/plain", trace::NonlinearSolveTrace)
126+
if trace.history !== nothing
127+
foreach(trace.history) do entry
128+
show(io, MIME"text/plain"(), entry)
129+
end
130+
else
131+
print(io, "Tracing Disabled")
132+
end
133+
end
134+
135+
function init_nonlinearsolve_trace(
136+
prob, alg, u, fu, J, δu; show_trace::Val = Val(false),
137+
trace_level::NonlinearSolveTracing = TraceMinimal(), store_trace::Val = Val(false),
138+
uses_jac_inverse = Val(false), kwargs...
139+
)
140+
return init_nonlinearsolve_trace(
141+
prob, alg, show_trace, trace_level, store_trace, u, fu, J, δu, uses_jac_inverse
142+
)
143+
end
144+
145+
function init_nonlinearsolve_trace(
146+
prob::AbstractNonlinearProblem, alg, show_trace::Val,
147+
trace_level::NonlinearSolveTracing, store_trace::Val, u, fu, J, δu,
148+
uses_jac_inverse::Val
149+
)
150+
if show_trace isa Val{true}
151+
print("\nAlgorithm: ")
152+
Base.printstyled(alg, "\n\n"; color = :green, bold = true)
153+
end
154+
J = uses_jac_inverse isa Val{true} ?
155+
(trace_level.trace_mode isa Val{:minimal} ? J : LinearAlgebra.pinv(J)) : J
156+
history = init_trace_history(prob, show_trace, trace_level, store_trace, u, fu, J, δu)
157+
return NonlinearSolveTrace(show_trace, store_trace, history, trace_level, prob)
158+
end
159+
160+
function init_trace_history(
161+
prob::AbstractNonlinearProblem, show_trace::Val, trace_level,
162+
store_trace::Val, u, fu, J, δu
163+
)
164+
store_trace isa Val{false} && show_trace isa Val{false} && return nothing
165+
entry = if trace_level.trace_mode isa Val{:minimal}
166+
NonlinearSolveTraceEntry(prob, 0, fu, δu, missing, missing)
167+
elseif trace_level.trace_mode isa Val{:condition_number}
168+
NonlinearSolveTraceEntry(prob, 0, fu, δu, J, missing)
169+
else
170+
NonlinearSolveTraceEntry(prob, 0, fu, δu, J, u)
171+
end
172+
show_trace isa Val{true} && show(stdout, MIME"text/plain"(), entry)
173+
store_trace isa Val{true} && return NonlinearSolveTraceEntry[entry]
174+
return nothing
175+
end
176+
177+
function update_trace!(
178+
trace::NonlinearSolveTrace, iter, u, fu, J, δu, α = true; last::Val = Val(false)
179+
)
180+
trace.store_trace isa Val{false} && trace.show_trace isa Val{false} && return nothing
181+
182+
if last isa Val{true}
183+
norm_type = ifelse(trace.prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
184+
fnorm = trace.prob isa NonlinearLeastSquaresProblem ? L2_NORM(fu) : Linf_NORM(fu)
185+
entry = NonlinearSolveTraceEntry(-1, fnorm, NaN32, nothing, nothing, norm_type)
186+
trace.show_trace isa Val{true} && show(stdout, MIME"text/plain"(), entry)
187+
return trace
188+
end
189+
190+
show_now = trace.show_trace isa Val{true} &&
191+
(mod1(iter, trace.trace_level.print_frequency) == 1)
192+
store_now = trace.store_trace isa Val{true} &&
193+
(mod1(iter, trace.trace_level.store_frequency) == 1)
194+
if show_now || store_now
195+
entry = if trace.trace_level.trace_mode isa Val{:minimal}
196+
NonlinearSolveTraceEntry(trace.prob, iter, fu, δu .* α, missing, missing)
197+
elseif trace.trace_level.trace_mode isa Val{:condition_number}
198+
NonlinearSolveTraceEntry(trace.prob, iter, fu, δu .* α, J, missing)
199+
else
200+
NonlinearSolveTraceEntry(trace.prob, iter, fu, δu .* α, J, u)
201+
end
202+
show_now && show(stdout, MIME"text/plain"(), entry)
203+
store_now && push!(trace.history, entry)
204+
end
205+
return trace
206+
end
207+
208+
function update_trace!(cache, α = true)
209+
trace = Utils.safe_getproperty(cache, Val(:trace))
210+
trace === missing && return nothing
211+
212+
J = Utils.safe_getproperty(cache, Val(:J))
213+
if J === missing
214+
update_trace!(
215+
trace, cache.nsteps + 1, get_u(cache), get_fu(cache), nothing, cache.du, α
216+
)
217+
# XXX: Implement
218+
# elseif cache isa ApproximateJacobianSolveCache && store_inverse_jacobian(cache)
219+
# update_trace!(trace, cache.nsteps + 1, get_u(cache), get_fu(cache),
220+
# ApplyArray(__safe_inv, J), cache.du, α)
221+
else
222+
update_trace!(trace, cache.nsteps + 1, get_u(cache), get_fu(cache), J, cache.du, α)
223+
end
224+
end

lib/NonlinearSolveBase/src/utils.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module Utils
22

33
using ArrayInterface: ArrayInterface
44
using FastClosures: @closure
5-
using LinearAlgebra: Symmetric, norm, dot
5+
using LinearAlgebra: LinearAlgebra, Diagonal, Symmetric, norm, dot, cond, diagind, pinv
6+
using MaybeInplace: @bb
67
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
78
using SciMLOperators: AbstractSciMLOperator
89
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NonlinearFunction
@@ -146,4 +147,46 @@ end
146147

147148
function make_sparse end
148149

150+
condition_number(J::AbstractMatrix) = cond(J)
151+
function condition_number(J::AbstractVector)
152+
if !ArrayInterface.can_setindex(J)
153+
J′ = similar(J)
154+
copyto!(J′, J)
155+
J = J′
156+
end
157+
return cond(Diagonal(J))
158+
end
159+
condition_number(::Any) = -1
160+
161+
# XXX: Move to NonlinearSolveQuasiNewton
162+
# compute `pinv` if `inv` won't work
163+
maybe_pinv!!_workspace(A) = nothing
164+
165+
maybe_pinv!!(workspace, A::Union{Number, AbstractMatrix}) = pinv(A)
166+
function maybe_pinv!!(workspace, A::Diagonal)
167+
D = A.diag
168+
@bb @. D = pinv(D)
169+
return Diagonal(D)
170+
end
171+
maybe_pinv!!(workspace, A::AbstractVector) = maybe_pinv!!(workspace, Diagonal(A))
172+
function maybe_pinv!!(workspace, A::StridedMatrix)
173+
LinearAlgebra.checksquare(A)
174+
if LinearAlgebra.istriu(A)
175+
issingular = any(iszero, @view(A[diagind(A)]))
176+
A_ = UpperTriangular(A)
177+
!issingular && return triu!(parent(inv(A_)))
178+
elseif LinearAlgebra.istril(A)
179+
A_ = LowerTriangular(A)
180+
issingular = any(iszero, @view(A_[diagind(A_)]))
181+
!issingular && return tril!(parent(inv(A_)))
182+
else
183+
F = LinearAlgebra.lu(A; check = false)
184+
if issuccess(F)
185+
Ai = LinearAlgebra.inv!(F)
186+
return convert(typeof(parent(Ai)), Ai)
187+
end
188+
end
189+
return pinv(A)
190+
end
191+
149192
end

src/NonlinearSolve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,15 @@ using NonlinearSolveBase: NonlinearSolveBase,
2929
DescentResult,
3030
SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogleg,
3131
GeodesicAcceleration,
32-
reset_timer!, @static_timeit
32+
reset_timer!, @static_timeit,
33+
init_nonlinearsolve_trace, update_trace!, reset!
3334

3435
# XXX: Remove
3536
import NonlinearSolveBase: InternalAPI, concrete_jac, supports_line_search,
3637
supports_trust_region, last_step_accepted, get_linear_solver,
3738
AbstractDampingFunction, AbstractDampingFunctionCache,
3839
requires_normal_form_jacobian, requires_normal_form_rhs,
39-
returns_norm_form_damping, get_timer_output
40+
returns_norm_form_damping, get_timer_output, get_u, get_fu
4041

4142
using Printf: @printf
4243
using Preferences: Preferences, set_preferences!
@@ -74,7 +75,6 @@ include("timer_outputs.jl")
7475
include("internal/helpers.jl")
7576

7677
include("internal/termination.jl")
77-
include("internal/tracing.jl")
7878
include("internal/approximate_initialization.jl")
7979

8080
include("globalization/line_search.jl")

0 commit comments

Comments
 (0)