Skip to content

Commit f20c9bc

Browse files
committed
Move other algorithms to use termination conditions
1 parent 416e656 commit f20c9bc

File tree

4 files changed

+121
-28
lines changed

4 files changed

+121
-28
lines changed

src/broyden.jl

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ end
3131
f
3232
alg
3333
u
34+
u_prev
3435
du
3536
fu
3637
fu2
@@ -46,17 +47,21 @@ end
4647
internalnorm
4748
retcode::ReturnCode.T
4849
abstol
50+
reltol
4951
reset_tolerance
5052
reset_check
5153
prob
5254
stats::NLStats
5355
lscache
56+
termination_condition
57+
tc_storage
5458
end
5559

5660
get_fu(cache::GeneralBroydenCache) = cache.fu
5761

5862
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
59-
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
63+
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
64+
termination_condition = nothing, internalnorm = DEFAULT_NORM,
6065
kwargs...) where {uType, iip}
6166
@unpack f, u0, p = prob
6267
u = alias_u0 ? u0 : deepcopy(u0)
@@ -65,23 +70,38 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
6570
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) :
6671
alg.reset_tolerance
6772
reset_check = x -> abs(x) reset_tolerance
68-
return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu),
73+
74+
abstol, reltol, termination_condition = _init_termination_elements(abstol,
75+
reltol,
76+
termination_condition,
77+
eltype(u))
78+
79+
mode = DiffEqBase.get_termination_mode(termination_condition)
80+
81+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
82+
nothing
83+
return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu),
6984
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
70-
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
85+
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
86+
reset_tolerance,
7187
reset_check, prob, NLStats(1, 0, 0, 0, 0),
72-
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
88+
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
89+
storage)
7390
end
7491

7592
function perform_step!(cache::GeneralBroydenCache{true})
76-
@unpack f, p, du, fu, fu2, dfu, u, J⁻¹, J⁻¹df, J⁻¹₂ = cache
93+
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂, tc_storage = cache
94+
95+
termination_condition = cache.termination_condition(tc_storage)
7796
T = eltype(u)
7897

7998
mul!(_vec(du), J⁻¹, -_vec(fu))
8099
α = perform_linesearch!(cache.lscache, u, du)
81100
_axpy!(α, du, u)
82101
f(fu2, u, p)
83102

84-
cache.internalnorm(fu2) < cache.abstol && (cache.force_stop = true)
103+
termination_condition(fu2, u, u_prev, cache.abstol, cache.reltol) &&
104+
(cache.force_stop = true)
85105
cache.stats.nf += 1
86106

87107
cache.force_stop && return nothing
@@ -106,20 +126,25 @@ function perform_step!(cache::GeneralBroydenCache{true})
106126
mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1)
107127
end
108128
fu .= fu2
129+
@. u_prev = u
109130

110131
return nothing
111132
end
112133

113134
function perform_step!(cache::GeneralBroydenCache{false})
114-
@unpack f, p = cache
135+
@unpack f, p, tc_storage = cache
136+
137+
termination_condition = cache.termination_condition(tc_storage)
138+
115139
T = eltype(cache.u)
116140

117141
cache.du = _restructure(cache.du, cache.J⁻¹ * -_vec(cache.fu))
118142
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
119143
cache.u = cache.u .+ α * cache.du
120144
cache.fu2 = f(cache.u, p)
121145

122-
cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
146+
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
147+
(cache.force_stop = true)
123148
cache.stats.nf += 1
124149

125150
cache.force_stop && return nothing
@@ -142,12 +167,15 @@ function perform_step!(cache::GeneralBroydenCache{false})
142167
cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂
143168
end
144169
cache.fu = cache.fu2
170+
cache.u_prev = @. cache.u
145171

146172
return nothing
147173
end
148174

149175
function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
150-
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
176+
abstol = cache.abstol, reltol = cache.reltol,
177+
termination_condition = cache.termination_condition,
178+
maxiters = cache.maxiters) where {iip}
151179
cache.p = p
152180
if iip
153181
recursivecopy!(cache.u, u0)
@@ -157,7 +185,14 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca
157185
cache.u = u0
158186
cache.fu = cache.f(cache.u, p)
159187
end
188+
termination_condition = _get_reinit_termination_condition(cache,
189+
abstol,
190+
reltol,
191+
termination_condition)
192+
160193
cache.abstol = abstol
194+
cache.reltol = reltol
195+
cache.termination_condition = termination_condition
161196
cache.maxiters = maxiters
162197
cache.stats.nf = 1
163198
cache.stats.nsteps = 1

src/klement.jl

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ end
4141
f
4242
alg
4343
u
44+
u_prev
4445
fu
4546
fu2
4647
du
@@ -65,7 +66,8 @@ end
6566
get_fu(cache::GeneralKlementCache) = cache.fu
6667

6768
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKlement, args...;
68-
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
69+
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
70+
termination_condition = nothing, internalnorm = DEFAULT_NORM,
6971
linsolve_kwargs = (;), kwargs...) where {uType, iip}
7072
@unpack f, u0, p = prob
7173
u = alias_u0 ? u0 : deepcopy(u0)
@@ -84,16 +86,30 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme
8486
linsolve = __setup_linsolve(J, _vec(fu), _vec(du), p, alg)
8587
end
8688

87-
return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), du, p, linsolve,
89+
abstol, reltol, termination_condition = _init_termination_elements(abstol,
90+
reltol,
91+
termination_condition,
92+
eltype(u))
93+
94+
mode = DiffEqBase.get_termination_mode(termination_condition)
95+
96+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
97+
nothing
98+
99+
return GeneralKlementCache{iip}(f, alg, u, zero(u), fu, zero(fu), du, p, linsolve,
88100
J, zero(J), zero(J), _vec(zero(fu)), _vec(zero(fu)), 0, false,
89-
maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0),
90-
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
101+
maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
102+
NLStats(1, 0, 0, 0, 0),
103+
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
104+
storage)
91105
end
92106

93107
function perform_step!(cache::GeneralKlementCache{true})
94-
@unpack u, fu, f, p, alg, J, linsolve, du = cache
108+
@unpack u, u_prev, fu, f, p, alg, J, linsolve, du, tc_storage = cache
95109
T = eltype(J)
96110

111+
termination_condition = cache.termination_condition(tc_storage)
112+
97113
singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)
98114

99115
if singular
@@ -118,7 +134,8 @@ function perform_step!(cache::GeneralKlementCache{true})
118134
_axpy!(α, du, u)
119135
f(cache.fu2, u, p)
120136

121-
cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
137+
termination_condition(cache.fu2, u, u_prev, cache.abstol, cache.reltol) &&
138+
(cache.force_stop = true)
122139
cache.stats.nf += 1
123140
cache.stats.nsolve += 1
124141
cache.stats.nfactors += 1
@@ -138,13 +155,17 @@ function perform_step!(cache::GeneralKlementCache{true})
138155
mul!(cache.J_cache2, cache.J_cache, J)
139156
J .+= cache.J_cache2
140157

158+
@. u_prev = u
141159
cache.fu .= cache.fu2
142160

143161
return nothing
144162
end
145163

146164
function perform_step!(cache::GeneralKlementCache{false})
147-
@unpack fu, f, p, alg, J, linsolve = cache
165+
@unpack fu, f, p, alg, J, linsolve, tc_storage = cache
166+
167+
termination_condition = cache.termination_condition(tc_storage)
168+
148169
T = eltype(J)
149170

150171
singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)
@@ -174,7 +195,10 @@ function perform_step!(cache::GeneralKlementCache{false})
174195
cache.u = @. cache.u + α * cache.du # `u` might not support mutation
175196
cache.fu2 = f(cache.u, p)
176197

177-
cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
198+
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
199+
(cache.force_stop = true)
200+
201+
cache.u_prev = @. cache.u
178202
cache.stats.nf += 1
179203
cache.stats.nsolve += 1
180204
cache.stats.nfactors += 1
@@ -198,7 +222,9 @@ function perform_step!(cache::GeneralKlementCache{false})
198222
end
199223

200224
function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = cache.p,
201-
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
225+
abstol = cache.abstol, reltol = cache.reltol,
226+
termination_condition = cache.termination_condition,
227+
maxiters = cache.maxiters) where {iip}
202228
cache.p = p
203229
if iip
204230
recursivecopy!(cache.u, u0)
@@ -208,7 +234,14 @@ function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = ca
208234
cache.u = u0
209235
cache.fu = cache.f(cache.u, p)
210236
end
237+
238+
termination_condition = _get_reinit_termination_condition(cache,
239+
abstol,
240+
reltol,
241+
termination_condition)
211242
cache.abstol = abstol
243+
cache.reltol = reltol
244+
cache.termination_condition = termination_condition
212245
cache.maxiters = maxiters
213246
cache.stats.nf = 1
214247
cache.stats.nsteps = 1

src/lbroyden.jl

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ end
3434
f
3535
alg
3636
u
37+
u_prev
3738
du
3839
fu
3940
fu2
@@ -53,17 +54,21 @@ end
5354
internalnorm
5455
retcode::ReturnCode.T
5556
abstol
57+
reltol
5658
reset_tolerance
5759
reset_check
5860
prob
5961
stats::NLStats
6062
lscache
63+
termination_condition
64+
tc_storage
6165
end
6266

6367
get_fu(cache::LimitedMemoryBroydenCache) = cache.fu
6468

6569
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LimitedMemoryBroyden,
66-
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
70+
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
71+
termination_condition = nothing, internalnorm = DEFAULT_NORM,
6772
kwargs...) where {uType, iip}
6873
@unpack f, u0, p = prob
6974
u = alias_u0 ? u0 : deepcopy(u0)
@@ -80,23 +85,38 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LimitedMemory
8085
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) :
8186
alg.reset_tolerance
8287
reset_check = x -> abs(x) reset_tolerance
83-
return LimitedMemoryBroydenCache{iip}(f, alg, u, du, fu, zero(fu),
88+
89+
abstol, reltol, termination_condition = _init_termination_elements(abstol,
90+
reltol,
91+
termination_condition,
92+
eltype(u))
93+
94+
mode = DiffEqBase.get_termination_mode(termination_condition)
95+
96+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
97+
nothing
98+
99+
return LimitedMemoryBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu),
84100
zero(fu), p, U, Vᵀ, similar(u, threshold), similar(u, 1, threshold),
85101
zero(u), zero(u), false, 0, 0, alg.max_resets, maxiters, internalnorm,
86-
ReturnCode.Default, abstol, reset_tolerance, reset_check, prob,
102+
ReturnCode.Default, abstol, reltol, reset_tolerance, reset_check, prob,
87103
NLStats(1, 0, 0, 0, 0),
88-
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
104+
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
105+
storage)
89106
end
90107

91108
function perform_step!(cache::LimitedMemoryBroydenCache{true})
92-
@unpack f, p, du, u = cache
109+
@unpack f, p, du, u, tc_storage = cache
93110
T = eltype(u)
94111

112+
termination_condition = cache.termination_condition(tc_storage)
113+
95114
α = perform_linesearch!(cache.lscache, u, du)
96115
_axpy!(α, du, u)
97116
f(cache.fu2, u, p)
98117

99-
cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
118+
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
119+
(cache.force_stop = true)
100120
cache.stats.nf += 1
101121

102122
cache.force_stop && return nothing
@@ -138,20 +158,25 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true})
138158
cache.iterations_since_reset += 1
139159
end
140160

161+
cache.u_prev .= cache.u
141162
cache.fu .= cache.fu2
142163

143164
return nothing
144165
end
145166

146167
function perform_step!(cache::LimitedMemoryBroydenCache{false})
147-
@unpack f, p = cache
168+
@unpack f, p, tc_storage = cache
169+
170+
termination_condition = cache.termination_condition(tc_storage)
171+
148172
T = eltype(cache.u)
149173

150174
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
151175
cache.u = cache.u .+ α * cache.du
152176
cache.fu2 = f(cache.u, p)
153177

154-
cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
178+
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
179+
(cache.force_stop = true)
155180
cache.stats.nf += 1
156181

157182
cache.force_stop && return nothing
@@ -194,6 +219,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{false})
194219
cache.iterations_since_reset += 1
195220
end
196221

222+
cache.u_prev = @. cache.u
197223
cache.fu = cache.fu2
198224

199225
return nothing

src/raphson.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,8 @@ function perform_step!(cache::NewtonRaphsonCache{true})
135135
end
136136

137137
function perform_step!(cache::NewtonRaphsonCache{false})
138-
@unpack u, u_prev, fu1, f, p, alg, linsolve = cache
138+
@unpack u, u_prev, fu1, f, p, alg, linsolve, tc_storage = cache
139139

140-
tc_storage = cache.tc_storage
141140
termination_condition = cache.termination_condition(tc_storage)
142141

143142
cache.J = jacobian!!(cache.J, cache)

0 commit comments

Comments
 (0)