Skip to content

Commit 0dd6194

Browse files
committed
Add true jacobian init for broyden and klement
1 parent 021901e commit 0dd6194

File tree

6 files changed

+172
-93
lines changed

6 files changed

+172
-93
lines changed

src/broyden.jl

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Sadly `Broyden` is taken up by SimpleNonlinearSolve.jl
22
"""
3-
GeneralBroyden(; max_resets = 3, linesearch = nothing, reset_tolerance = nothing)
3+
GeneralBroyden(; max_resets = 3, linesearch = nothing, reset_tolerance = nothing,
4+
init_jacobian::Val = Val(:identity), autodiff = nothing)
45
56
An implementation of `Broyden` with resetting and line search.
67
@@ -14,20 +15,36 @@ An implementation of `Broyden` with resetting and line search.
1415
used here directly, and they will be converted to the correct `LineSearch`. It is
1516
recommended to use [LiFukushimaLineSearch](@ref) -- a derivative free linesearch
1617
specifically designed for Broyden's method.
18+
- `init_jacobian`: the method to use for initializing the jacobian. Defaults to using the
19+
identity matrix (`Val(:identity)`). Alternatively, can be set to `Val(:true_jacobian)`
20+
to use the true jacobian as initialization.
21+
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
22+
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
23+
`nothing` which means that a default is selected according to the problem specification!
24+
Valid choices are types from ADTypes.jl. (Used if `init_jacobian = Val(:true_jacobian)`)
1725
"""
18-
@concrete struct GeneralBroyden <: AbstractNewtonAlgorithm{false, Nothing}
26+
@concrete struct GeneralBroyden{IJ, CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
27+
ad::AD
1928
max_resets::Int
2029
reset_tolerance
2130
linesearch
2231
end
2332

33+
function set_ad(alg::GeneralBroyden{IJ, CJ}, ad) where {IJ, CJ}
34+
return GeneralBroyden{IJ, CJ}(ad, alg.max_resets, alg.reset_tolerance, alg.linesearch)
35+
end
36+
2437
function GeneralBroyden(; max_resets = 3, linesearch = nothing,
25-
reset_tolerance = nothing)
38+
reset_tolerance = nothing, init_jacobian::Val = Val(:identity),
39+
autodiff = nothing)
40+
IJ = _unwrap_val(init_jacobian)
41+
@assert IJ ∈ (:identity, :true_jacobian)
2642
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
27-
return GeneralBroyden(max_resets, reset_tolerance, linesearch)
43+
CJ = IJ === :true_jacobian
44+
return GeneralBroyden{IJ, CJ}(autodiff, max_resets, reset_tolerance, linesearch)
2845
end
2946

30-
@concrete mutable struct GeneralBroydenCache{iip} <: AbstractNonlinearSolveCache{iip}
47+
@concrete mutable struct GeneralBroydenCache{iip, IJ} <: AbstractNonlinearSolveCache{iip}
3148
f
3249
alg
3350
u
@@ -37,6 +54,7 @@ end
3754
fu_cache
3855
dfu
3956
p
57+
uf
4058
J⁻¹
4159
J⁻¹dfu
4260
force_stop::Bool
@@ -49,45 +67,62 @@ end
4967
reltol
5068
reset_tolerance
5169
reset_check
70+
jac_cache
5271
prob
5372
stats::NLStats
5473
ls_cache
5574
tc_cache
5675
trace
5776
end
5877

59-
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
60-
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
78+
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralBroyden{IJ},
79+
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
6180
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
62-
kwargs...) where {uType, iip, F}
81+
kwargs...) where {uType, iip, F, IJ}
6382
@unpack f, u0, p = prob
6483
u = __maybe_unaliased(u0, alias_u0)
6584
fu = evaluate_f(prob, u)
6685
@bb du = copy(u)
67-
J⁻¹ = __init_identity_jacobian(u, fu)
86+
87+
if IJ === :true_jacobian
88+
alg = get_concrete_algorithm(alg_, prob)
89+
uf, _, J, fu_cache, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
90+
lininit = Val(false))
91+
J⁻¹ = J
92+
else
93+
alg = alg_
94+
@bb du = similar(u)
95+
uf, fu_cache, jac_cache = nothing, nothing, nothing
96+
J⁻¹ = __init_identity_jacobian(u, fu)
97+
end
98+
6899
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) :
69100
alg.reset_tolerance
70101
reset_check = x -> abs(x) ≤ reset_tolerance
71102

72103
@bb u_cache = copy(u)
73-
@bb fu_cache = copy(fu)
74-
@bb dfu = similar(fu)
104+
@bb dfu = copy(fu)
75105
@bb J⁻¹dfu = similar(u)
76106

77107
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
78108
termination_condition)
79109
trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true),
80110
kwargs...)
81111

82-
return GeneralBroydenCache{iip}(f, alg, u, u_cache, du, fu, fu_cache, dfu, p,
112+
return GeneralBroydenCache{iip, IJ}(f, alg, u, u_cache, du, fu, fu_cache, dfu, p, uf,
83113
J⁻¹, J⁻¹dfu, false, 0, alg.max_resets, maxiters, internalnorm, ReturnCode.Default,
84-
abstol, reltol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
114+
abstol, reltol, reset_tolerance, reset_check, jac_cache, prob,
115+
NLStats(1, 0, 0, 0, 0),
85116
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
86117
end
87118

88-
function perform_step!(cache::GeneralBroydenCache{iip}) where {iip}
119+
function perform_step!(cache::GeneralBroydenCache{iip, IJ}) where {iip, IJ}
89120
T = eltype(cache.u)
90121

122+
if IJ === :true_jacobian && cache.stats.nsteps == 0
123+
cache.J⁻¹ = inv(jacobian!!(cache.J⁻¹, cache)) # This allocates
124+
end
125+
91126
@bb cache.du = cache.J⁻¹ × vec(cache.fu)
92127
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
93128
@bb axpy!(-α, cache.du, cache.u)
@@ -100,15 +135,19 @@ function perform_step!(cache::GeneralBroydenCache{iip}) where {iip}
100135
cache.force_stop && return nothing
101136

102137
# Update the inverse jacobian
103-
@bb @. cache.dfu = cache.fu - cache.fu_cache
138+
@bb @. cache.dfu = cache.fu - cache.dfu
104139

105140
if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu)
106141
if cache.resets ≥ cache.max_resets
107142
cache.retcode = ReturnCode.ConvergenceFailure
108143
cache.force_stop = true
109144
return nothing
110145
end
111-
cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹)
146+
if IJ === :true_jacobian
147+
cache.J⁻¹ = inv(jacobian!!(cache.J⁻¹, cache))
148+
else
149+
cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹)
150+
end
112151
cache.resets += 1
113152
else
114153
@bb cache.du .*= -1
@@ -119,7 +158,7 @@ function perform_step!(cache::GeneralBroydenCache{iip}) where {iip}
119158
@bb cache.J⁻¹ += vec(cache.du) × transpose(_vec(cache.u_cache))
120159
end
121160

122-
@bb copyto!(cache.fu_cache, cache.fu)
161+
@bb copyto!(cache.dfu, cache.fu)
123162
@bb copyto!(cache.u_cache, cache.u)
124163

125164
return nothing

src/jacobian.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ function jacobian!!(::Number, cache)
5858
cache.stats.njacs += 1
5959
return last(value_derivative(cache.uf, cache.u))
6060
end
61+
6162
# Build Jacobian Caches
6263
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val{iip};
6364
linsolve_kwargs = (;), lininit::Val{linsolve_init} = Val(true),

src/klement.jl

Lines changed: 100 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -19,33 +19,54 @@ solves.
1919
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
2020
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
2121
used here directly, and they will be converted to the correct `LineSearch`.
22+
- `init_jacobian`: the method to use for initializing the jacobian. Defaults to using the
23+
identity matrix (`Val(:identity)`). Alternatively, can be set to `Val(:true_jacobian)`
24+
to use the true jacobian as initialization. (Our tests suggest it is a good idea to
25+
initialize with an identity matrix)
26+
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
27+
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
28+
`nothing` which means that a default is selected according to the problem specification!
29+
Valid choices are types from ADTypes.jl. (Used if `init_jacobian = Val(:true_jacobian)`)
2230
"""
23-
@concrete struct GeneralKlement <: AbstractNewtonAlgorithm{false, Nothing}
31+
@concrete struct GeneralKlement{IJ, CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
32+
ad::AD
2433
max_resets::Int
2534
linsolve
2635
precs
2736
linesearch
2837
end
2938

30-
function set_linsolve(alg::GeneralKlement, linsolve)
31-
return GeneralKlement(alg.max_resets, linsolve, alg.precs, alg.linesearch)
39+
function set_linsolve(alg::GeneralKlement{IJ, CS}, linsolve) where {IJ, CS}
40+
return GeneralKlement{IJ, CS}(alg.ad, alg.max_resets, linsolve, alg.precs,
41+
alg.linesearch)
42+
end
43+
44+
function set_ad(alg::GeneralKlement{IJ, CS}, ad) where {IJ, CS}
45+
return GeneralKlement{IJ, CS}(ad, alg.max_resets, alg.linsolve, alg.precs,
46+
alg.linesearch)
3247
end
3348

3449
function GeneralKlement(; max_resets::Int = 5, linsolve = nothing,
35-
linesearch = nothing, precs = DEFAULT_PRECS)
50+
linesearch = nothing, precs = DEFAULT_PRECS, init_jacobian::Val = Val(:identity),
51+
autodiff = nothing)
52+
IJ = _unwrap_val(init_jacobian)
53+
@assert IJ ∈ (:identity, :true_jacobian)
3654
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
37-
return GeneralKlement(max_resets, linsolve, precs, linesearch)
55+
CJ = IJ === :true_jacobian
56+
return GeneralKlement{IJ, CJ}(autodiff, max_resets, linsolve, precs, linesearch)
3857
end
3958

40-
@concrete mutable struct GeneralKlementCache{iip} <: AbstractNonlinearSolveCache{iip}
59+
@concrete mutable struct GeneralKlementCache{iip, IJ} <: AbstractNonlinearSolveCache{iip}
4160
f
4261
alg
4362
u
4463
u_cache
4564
fu
4665
fu_cache
66+
fu_cache_2
4767
du
4868
p
69+
uf
4970
linsolve
5071
J
5172
J_cache
@@ -60,73 +81,96 @@ end
6081
abstol
6182
reltol
6283
prob
84+
jac_cache
6385
stats::NLStats
6486
ls_cache
6587
tc_cache
6688
trace
6789
end
6890

69-
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKlement, args...;
70-
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
91+
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKlement{IJ},
92+
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
7193
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
72-
linsolve_kwargs = (;), kwargs...) where {uType, iip, F}
94+
linsolve_kwargs = (;), kwargs...) where {uType, iip, F, IJ}
7395
@unpack f, u0, p = prob
7496
u = __maybe_unaliased(u0, alias_u0)
7597
fu = evaluate_f(prob, u)
76-
J = __init_identity_jacobian(u, fu)
77-
@bb du = similar(u)
7898

79-
if u isa Number
80-
linsolve = FakeLinearSolveJLCache(J, fu)
99+
if IJ === :true_jacobian
100+
alg = get_concrete_algorithm(alg_, prob)
101+
uf, _, J, fu_cache, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
102+
lininit = Val(false))
103+
elseif IJ === :identity
81104
alg = alg_
105+
@bb du = similar(u)
106+
uf, fu_cache, jac_cache = nothing, nothing, nothing
107+
J = one.(u) # Identity Init Jacobian for Klement maintains a Diagonal Structure
108+
else
109+
error("Invalid `init_jacobian` value")
110+
end
111+
112+
if IJ === :true_jacobian
113+
linsolve = linsolve_caches(J, _vec(fu), _vec(du), p, alg_; linsolve_kwargs)
82114
else
83-
# For General Julia Arrays default to LU Factorization
84-
linsolve_alg = (alg_.linsolve === nothing && (u isa Array || u isa StaticArray)) ?
85-
LUFactorization() : nothing
86-
alg = set_linsolve(alg_, linsolve_alg)
87-
linsolve = linsolve_caches(J, _vec(fu), _vec(du), p, alg; linsolve_kwargs)
115+
linsolve = nothing
88116
end
89117

90118
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
91119
termination_condition)
92120
trace = init_nonlinearsolve_trace(alg, u, fu, J, du; kwargs...)
93121

94122
@bb u_cache = copy(u)
95-
@bb fu_cache = copy(fu)
96-
@bb J_cache = similar(J)
97-
@bb J_cache_2 = similar(J)
123+
@bb fu_cache_2 = copy(fu)
98124
@bb Jdu = similar(fu)
99-
@bb Jdu_cache = similar(fu)
125+
if IJ === :true_jacobian
126+
@bb J_cache = similar(J)
127+
@bb J_cache_2 = similar(J)
128+
@bb Jdu_cache = similar(fu)
129+
else
130+
J_cache, J_cache_2, Jdu_cache = nothing, nothing, nothing
131+
end
100132

101-
return GeneralKlementCache{iip}(f, alg, u, u_cache, fu, fu_cache, du, p, linsolve,
102-
J, J_cache, J_cache_2, Jdu, Jdu_cache, 0, false, maxiters, internalnorm,
103-
ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0),
133+
return GeneralKlementCache{iip, IJ}(f, alg, u, u_cache, fu, fu_cache, fu_cache_2, du, p,
134+
uf, linsolve, J, J_cache, J_cache_2, Jdu, Jdu_cache, 0, false, maxiters,
135+
internalnorm,
136+
ReturnCode.Default, abstol, reltol, prob, jac_cache, NLStats(1, 0, 0, 0, 0),
104137
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
105138
end
106139

107-
function perform_step!(cache::GeneralKlementCache{iip}) where {iip}
140+
function perform_step!(cache::GeneralKlementCache{iip, IJ}) where {iip, IJ}
108141
@unpack linsolve, alg = cache
109142
T = eltype(cache.J)
110-
singular, fact_done = __try_factorize_and_check_singular!(linsolve, cache.J)
111143

112-
if singular
144+
if IJ === :true_jacobian
145+
cache.stats.nsteps == 0 && (cache.J = jacobian!!(cache.J, cache))
146+
ill_conditioned = __is_ill_conditioned(cache.J)
147+
elseif IJ === :identity
148+
ill_conditioned = __is_ill_conditioned(cache.J)
149+
end
150+
151+
if ill_conditioned
113152
if cache.resets == alg.max_resets
114153
cache.force_stop = true
115154
cache.retcode = ReturnCode.ConvergenceFailure
116155
return nothing
117156
end
118-
fact_done = false
119-
cache.J = __reinit_identity_jacobian!!(cache.J)
157+
if IJ === :true_jacobian && cache.stats.nsteps != 0
158+
cache.J = jacobian!!(cache.J, cache)
159+
else
160+
cache.J = __reinit_identity_jacobian!!(cache.J)
161+
end
120162
cache.resets += 1
121163
end
122164

123-
A = ifelse(cache.J isa SMatrix || cache.J isa Number || !fact_done, cache.J, nothing)
124-
125-
# u = u - J \ fu
126-
linres = dolinsolve(cache, alg.precs, cache.linsolve; A,
127-
b = _vec(cache.fu), linu = _vec(cache.du), cache.p, reltol = cache.abstol)
128-
cache.linsolve = linres.cache
129-
cache.du = _restructure(cache.du, linres.u)
165+
if IJ === :identity
166+
@bb @. cache.du = cache.fu / cache.J
167+
else
168+
# u = u - J \ fu
169+
linres = dolinsolve(cache, alg.precs, cache.linsolve; A = cache.J,
170+
b = _vec(cache.fu), linu = _vec(cache.du), cache.p, reltol = cache.abstol)
171+
cache.linsolve = linres.cache
172+
cache.du = _restructure(cache.du, linres.u)
173+
end
130174

131175
# Line Search
132176
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
@@ -143,18 +187,25 @@ function perform_step!(cache::GeneralKlementCache{iip}) where {iip}
143187

144188
# Update the Jacobian
145189
@bb cache.du .*= -1
146-
@bb cache.J_cache .= cache.J' .^ 2
147-
@bb @. cache.Jdu = cache.du^2
148-
@bb cache.Jdu_cache = cache.J_cache × vec(cache.Jdu)
149-
@bb cache.Jdu = cache.J × vec(cache.du)
150-
@bb @. cache.fu_cache = (cache.fu - cache.fu_cache - cache.Jdu) /
151-
ifelse(iszero(cache.Jdu_cache), T(1e-5), cache.Jdu_cache)
152-
@bb cache.J_cache = vec(cache.fu_cache) × transpose(_vec(cache.du))
153-
@bb @. cache.J_cache *= cache.J
154-
@bb cache.J_cache_2 = cache.J_cache × cache.J
155-
@bb cache.J .+= cache.J_cache_2
156-
157-
@bb copyto!(cache.fu_cache, cache.fu)
190+
if IJ === :identity
191+
@bb @. cache.Jdu = (cache.J^2) * (cache.du^2)
192+
@bb @. cache.J += ((cache.fu - cache.fu_cache_2 - cache.J * cache.du) /
193+
ifelse(iszero(cache.Jdu), T(1e-5), cache.Jdu)) * cache.du *
194+
(cache.J^2)
195+
else
196+
@bb cache.J_cache .= cache.J' .^ 2
197+
@bb @. cache.Jdu = cache.du^2
198+
@bb cache.Jdu_cache = cache.J_cache × vec(cache.Jdu)
199+
@bb cache.Jdu = cache.J × vec(cache.du)
200+
@bb @. cache.fu_cache_2 = (cache.fu - cache.fu_cache_2 - cache.Jdu) /
201+
ifelse(iszero(cache.Jdu_cache), T(1e-5), cache.Jdu_cache)
202+
@bb cache.J_cache = vec(cache.fu_cache_2) × transpose(_vec(cache.du))
203+
@bb @. cache.J_cache *= cache.J
204+
@bb cache.J_cache_2 = cache.J_cache × cache.J
205+
@bb cache.J .+= cache.J_cache_2
206+
end
207+
208+
@bb copyto!(cache.fu_cache_2, cache.fu)
158209

159210
return nothing
160211
end

0 commit comments

Comments
 (0)