Skip to content

Commit 00852f0

Browse files
committed
Fast General Klement Implementation
1 parent 1e4cfde commit 00852f0

File tree

7 files changed

+302
-12
lines changed

7 files changed

+302
-12
lines changed

docs/src/solvers/NonlinearSystemSolvers.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ features, but have a bit of overhead on very small problems.
6767
robustnes on the hard problems.
6868
- `GeneralBroyden()`: Generalization of Broyden's Quasi-Newton Method with Line Search and
6969
Automatic Jacobian Resetting. This is a fast method but unstable for most problems!
70+
- `GeneralKlement()`: Generalization of Klement's Quasi-Newton Method with Line Search and
71+
Automatic Jacobian Resetting. This is a fast method but unstable for most problems!
7072

7173
### SimpleNonlinearSolve.jl
7274

src/NonlinearSolve.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,13 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
5252
cache.stats.nsteps += 1
5353
end
5454

55-
if cache.stats.nsteps == cache.maxiters
56-
cache.retcode = ReturnCode.MaxIters
57-
else
58-
cache.retcode = ReturnCode.Success
55+
# The solver might have set a different `retcode`
56+
if cache.retcode == ReturnCode.Default
57+
if cache.stats.nsteps == cache.maxiters
58+
cache.retcode = ReturnCode.MaxIters
59+
else
60+
cache.retcode = ReturnCode.Success
61+
end
5962
end
6063

6164
return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, get_fu(cache);
@@ -85,7 +88,7 @@ import PrecompileTools
8588
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
8689

8790
precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
88-
nothing)
91+
PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)
8992

9093
for alg in precompile_algs
9194
solve(prob, alg, abstol = T(1e-2))

src/broyden.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,7 @@ function perform_step!(cache::GeneralBroydenCache{false})
113113
cache.dfu = cache.fu2 .- cache.fu
114114
if cache.resets < cache.max_resets &&
115115
(all(x -> abs(x) 1e-12, cache.du) || all(x -> abs(x) 1e-12, cache.dfu))
116-
J⁻¹ = similar(cache.J⁻¹)
117-
fill!(J⁻¹, 0)
118-
J⁻¹[diagind(J⁻¹)] .= T(1)
119-
cache.J⁻¹ = J⁻¹
116+
cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu)
120117
cache.resets += 1
121118
else
122119
cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu))

src/default.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,8 @@ end
159159
]
160160
else
161161
[
162-
# FIXME: Broyden and Klement are type unstable
163-
# (upstream SimpleNonlinearSolve.jl issue)
164-
!iip ? :(Klement()) : nothing, # Klement not yet implemented for IIP
165162
:(GeneralBroyden()),
163+
:(GeneralKlement()),
166164
:(NewtonRaphson(; linsolve, precs, adkwargs...)),
167165
:(NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...)),
168166
:(TrustRegion(; linsolve, precs, adkwargs...)),

src/klement.jl

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,191 @@
1+
@concrete struct GeneralKlement <: AbstractNewtonAlgorithm{false, Nothing}
2+
max_resets::Int
3+
linsolve
4+
precs
5+
linesearch
6+
singular_tolerance
7+
end
18

9+
function GeneralKlement(; max_resets::Int = 5, linsolve = nothing,
10+
linesearch = LineSearch(), precs = DEFAULT_PRECS, singular_tolerance = nothing)
11+
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
12+
return GeneralKlement(max_resets, linsolve, precs, linesearch, singular_tolerance)
13+
end
14+
15+
@concrete mutable struct GeneralKlementCache{iip} <: AbstractNonlinearSolveCache{iip}
16+
f
17+
alg
18+
u
19+
fu
20+
fu2
21+
du
22+
p
23+
linsolve
24+
J
25+
J_cache
26+
J_cache2
27+
Jᵀ²du
28+
Jdu
29+
resets
30+
singular_tolerance
31+
force_stop
32+
maxiters::Int
33+
internalnorm
34+
retcode::ReturnCode.T
35+
abstol
36+
prob
37+
stats::NLStats
38+
lscache
39+
end
40+
41+
get_fu(cache::GeneralKlementCache) = cache.fu
42+
43+
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralKlement, args...;
44+
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
45+
linsolve_kwargs = (;), kwargs...) where {uType, iip}
46+
@unpack f, u0, p = prob
47+
u = alias_u0 ? u0 : deepcopy(u0)
48+
fu = evaluate_f(prob, u)
49+
J = __init_identity_jacobian(u, fu)
50+
51+
if u isa Number
52+
linsolve = nothing
53+
else
54+
weight = similar(u)
55+
recursivefill!(weight, true)
56+
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
57+
nothing)..., weight)
58+
linprob = LinearProblem(J, _vec(fu); u0 = _vec(fu))
59+
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
60+
linsolve_kwargs...)
61+
end
62+
63+
singular_tolerance = alg.singular_tolerance === nothing ? inv(sqrt(eps(eltype(u)))) :
64+
eltype(u)(alg.singular_tolerance)
65+
66+
return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), _mutable_zero(u), p, linsolve,
67+
J, zero(J), zero(J), zero(fu), zero(fu), 0, singular_tolerance, false,
68+
maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0),
69+
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
70+
end
71+
72+
function perform_step!(cache::GeneralKlementCache{true})
73+
@unpack u, fu, f, p, alg, J, linsolve, du = cache
74+
T = eltype(J)
75+
76+
# FIXME: How can we do this faster?
77+
if cond(J) > cache.singular_tolerance
78+
if cache.resets == alg.max_resets
79+
cache.force_stop = true
80+
cache.retcode = ReturnCode.Unstable
81+
return nothing
82+
end
83+
fill!(J, zero(T))
84+
J[diagind(J)] .= T(1)
85+
cache.resets += 1
86+
end
87+
88+
# u = u - J \ fu
89+
linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu), linu = _vec(du),
90+
p, reltol = cache.abstol)
91+
cache.linsolve = linres.cache
92+
93+
# Line Search
94+
α = perform_linesearch!(cache.lscache, u, du)
95+
axpy!(α, du, u)
96+
f(cache.fu2, u, p)
97+
98+
cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
99+
cache.stats.nf += 1
100+
cache.stats.nsolve += 1
101+
cache.stats.nfactors += 1
102+
103+
cache.force_stop && return nothing
104+
105+
# Update the Jacobian
106+
cache.J_cache .= cache.J' .^ 2
107+
cache.Jdu .= _vec(du) .^ 2
108+
mul!(cache.Jᵀ²du, cache.J_cache, cache.Jdu)
109+
mul!(cache.Jdu, J, _vec(du))
110+
cache.fu .= cache.fu2 .- cache.fu
111+
cache.fu .= (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T))
112+
mul!(cache.J_cache, _vec(cache.fu), _vec(du)')
113+
cache.J_cache .*= J
114+
mul!(cache.J_cache2, cache.J_cache, J)
115+
J .+= cache.J_cache2
116+
117+
cache.fu .= cache.fu2
118+
119+
return nothing
120+
end
121+
122+
function perform_step!(cache::GeneralKlementCache{false})
123+
@unpack fu, f, p, alg, J, linsolve = cache
124+
T = eltype(J)
125+
126+
# FIXME: How can we do this faster?
127+
if cond(J) > cache.singular_tolerance
128+
if cache.resets == alg.max_resets
129+
cache.force_stop = true
130+
cache.retcode = ReturnCode.Unstable
131+
return nothing
132+
end
133+
cache.J = __init_identity_jacobian(u, fu)
134+
cache.resets += 1
135+
end
136+
137+
# u = u - J \ fu
138+
if linsolve === nothing
139+
cache.du = -fu / cache.J
140+
else
141+
linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu),
142+
linu = _vec(cache.du), p, reltol = cache.abstol)
143+
cache.linsolve = linres.cache
144+
end
145+
146+
# Line Search
147+
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
148+
cache.u = @. cache.u + α * cache.du # `u` might not support mutation
149+
cache.fu2 = f(cache.u, p)
150+
151+
cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
152+
cache.stats.nf += 1
153+
cache.stats.nsolve += 1
154+
cache.stats.nfactors += 1
155+
156+
cache.force_stop && return nothing
157+
158+
# Update the Jacobian
159+
cache.J_cache = cache.J' .^ 2
160+
cache.Jdu = _vec(cache.du) .^ 2
161+
cache.Jᵀ²du = cache.J_cache * cache.Jdu
162+
cache.Jdu = J * _vec(cache.du)
163+
cache.fu = cache.fu2 .- cache.fu
164+
cache.fu = (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T))
165+
cache.J_cache = ((_vec(cache.fu) * _vec(cache.du)') .* J) * J
166+
cache.J = J .+ cache.J_cache
167+
168+
cache.fu = cache.fu2
169+
170+
return nothing
171+
end
172+
173+
function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = cache.p,
174+
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
175+
cache.p = p
176+
if iip
177+
recursivecopy!(cache.u, u0)
178+
cache.f(cache.fu, cache.u, p)
179+
else
180+
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
181+
cache.u = u0
182+
cache.fu = cache.f(cache.u, p)
183+
end
184+
cache.abstol = abstol
185+
cache.maxiters = maxiters
186+
cache.stats.nf = 1
187+
cache.stats.nsteps = 1
188+
cache.force_stop = false
189+
cache.retcode = ReturnCode.Default
190+
return cache
191+
end

test/23_test_problems.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,14 @@ end
9292

9393
test_on_library(problems, dicts, alg_ops, broken_tests)
9494
end
95+
96+
@testset "GeneralKlement 23 Test Problems" begin
97+
alg_ops = (GeneralKlement(),
98+
GeneralKlement(; linesearch = BackTracking()))
99+
100+
broken_tests = Dict(alg => Int[] for alg in alg_ops)
101+
broken_tests[alg_ops[1]] = [1, 2, 3, 4, 5, 6, 7, 13, 22]
102+
broken_tests[alg_ops[2]] = [1, 2, 4, 5, 6, 7, 11, 12, 22]
103+
104+
test_on_library(problems, dicts, alg_ops, broken_tests)
105+
end

test/basictests.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,3 +754,92 @@ end
754754
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) sqrt.(p)
755755
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) sqrt.(p)
756756
end
757+
758+
# --- GeneralKlement tests ---
759+
760+
@testset "GeneralKlement" begin
761+
function benchmark_nlsolve_oop(f, u0, p = 2.0; linesearch = LineSearch())
762+
prob = NonlinearProblem{false}(f, u0, p)
763+
return solve(prob, GeneralKlement(; linesearch), abstol = 1e-9)
764+
end
765+
766+
function benchmark_nlsolve_iip(f, u0, p = 2.0; linesearch = LineSearch())
767+
prob = NonlinearProblem{true}(f, u0, p)
768+
return solve(prob, GeneralKlement(; linesearch), abstol = 1e-9)
769+
end
770+
771+
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (Static(),
772+
StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
773+
ad in (AutoFiniteDiff(), AutoZygote())
774+
775+
linesearch = LineSearch(; method = lsmethod, autodiff = ad)
776+
u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
777+
778+
@testset "[OOP] u0: $(typeof(u0))" for u0 in u0s
779+
sol = benchmark_nlsolve_oop(quadratic_f, u0; linesearch)
780+
@test SciMLBase.successful_retcode(sol)
781+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
782+
783+
cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0),
784+
GeneralKlement(; linesearch), abstol = 1e-9)
785+
@test (@ballocated solve!($cache)) < 200
786+
end
787+
788+
@testset "[IIP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0],)
789+
ad isa AutoZygote && continue
790+
sol = benchmark_nlsolve_iip(quadratic_f!, u0; linesearch)
791+
@test SciMLBase.successful_retcode(sol)
792+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
793+
794+
cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0),
795+
GeneralKlement(; linesearch), abstol = 1e-9)
796+
@test (@ballocated solve!($cache)) 64
797+
end
798+
end
799+
800+
@testset "[OOP] [Immutable AD]" begin
801+
for p in 1.0:0.1:100.0
802+
@test begin
803+
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
804+
res_true = sqrt(p)
805+
all(res.u .≈ res_true)
806+
end
807+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
808+
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
809+
end
810+
end
811+
812+
@testset "[OOP] [Scalar AD]" begin
813+
for p in 1.0:0.1:100.0
814+
@test begin
815+
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
816+
res_true = sqrt(p)
817+
res.u res_true
818+
end
819+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
820+
p) 1 / (2 * sqrt(p))
821+
end
822+
end
823+
824+
t = (p) -> [sqrt(p[2] / p[1])]
825+
p = [0.9, 50.0]
826+
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u sqrt(p[2] / p[1])
827+
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
828+
p) ForwardDiff.jacobian(t, p)
829+
830+
# Iterator interface
831+
function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip}
832+
probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin])
833+
cache = init(probN, GeneralKlement(); maxiters = 100, abstol = 1e-10)
834+
sols = zeros(length(p_range))
835+
for (i, p) in enumerate(p_range)
836+
reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p)
837+
sol = solve!(cache)
838+
sols[i] = iip ? sol.u[1] : sol.u
839+
end
840+
return sols
841+
end
842+
p = range(0.01, 2, length = 200)
843+
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) sqrt.(p)
844+
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) sqrt.(p)
845+
end

0 commit comments

Comments
 (0)