Skip to content

Commit dbed34c

Browse files
Merge pull request #208 from utkarsh530/u/termination_condition
Start using termination conditions from DiffEqBase
2 parents 191a237 + 350fac5 commit dbed34c

13 files changed

+548
-92
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
7979
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
8080
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8181
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
82+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
8283

8384
[targets]
84-
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices"]
85+
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase"]

src/broyden.jl

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ end
3131
f
3232
alg
3333
u
34+
u_prev
3435
du
3536
fu
3637
fu2
@@ -46,17 +47,21 @@ end
4647
internalnorm
4748
retcode::ReturnCode.T
4849
abstol
50+
reltol
4951
reset_tolerance
5052
reset_check
5153
prob
5254
stats::NLStats
5355
lscache
56+
termination_condition
57+
tc_storage
5458
end
5559

5660
get_fu(cache::GeneralBroydenCache) = cache.fu
5761

5862
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
59-
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
63+
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
64+
termination_condition = nothing, internalnorm = DEFAULT_NORM,
6065
kwargs...) where {uType, iip}
6166
@unpack f, u0, p = prob
6267
u = alias_u0 ? u0 : deepcopy(u0)
@@ -65,23 +70,38 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
6570
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) :
6671
alg.reset_tolerance
6772
reset_check = x -> abs(x) ≤ reset_tolerance
68-
return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu),
73+
74+
abstol, reltol, termination_condition = _init_termination_elements(abstol,
75+
reltol,
76+
termination_condition,
77+
eltype(u))
78+
79+
mode = DiffEqBase.get_termination_mode(termination_condition)
80+
81+
storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
82+
nothing
83+
return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu),
6984
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
70-
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
85+
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
86+
reset_tolerance,
7187
reset_check, prob, NLStats(1, 0, 0, 0, 0),
72-
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
88+
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
89+
storage)
7390
end
7491

7592
function perform_step!(cache::GeneralBroydenCache{true})
76-
@unpack f, p, du, fu, fu2, dfu, u, J⁻¹, J⁻¹df, J⁻¹₂ = cache
93+
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂, tc_storage = cache
94+
95+
termination_condition = cache.termination_condition(tc_storage)
7796
T = eltype(u)
7897

7998
mul!(_vec(du), J⁻¹, -_vec(fu))
8099
α = perform_linesearch!(cache.lscache, u, du)
81100
_axpy!(α, du, u)
82101
f(fu2, u, p)
83102

84-
cache.internalnorm(fu2) < cache.abstol && (cache.force_stop = true)
103+
termination_condition(fu2, u, u_prev, cache.abstol, cache.reltol) &&
104+
(cache.force_stop = true)
85105
cache.stats.nf += 1
86106

87107
cache.force_stop && return nothing
@@ -106,20 +126,25 @@ function perform_step!(cache::GeneralBroydenCache{true})
106126
mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1)
107127
end
108128
fu .= fu2
129+
@. u_prev = u
109130

110131
return nothing
111132
end
112133

113134
function perform_step!(cache::GeneralBroydenCache{false})
114-
@unpack f, p = cache
135+
@unpack f, p, tc_storage = cache
136+
137+
termination_condition = cache.termination_condition(tc_storage)
138+
115139
T = eltype(cache.u)
116140

117141
cache.du = _restructure(cache.du, cache.J⁻¹ * -_vec(cache.fu))
118142
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
119143
cache.u = cache.u .+ α * cache.du
120144
cache.fu2 = f(cache.u, p)
121145

122-
cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
146+
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
147+
(cache.force_stop = true)
123148
cache.stats.nf += 1
124149

125150
cache.force_stop && return nothing
@@ -142,12 +167,15 @@ function perform_step!(cache::GeneralBroydenCache{false})
142167
cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂
143168
end
144169
cache.fu = cache.fu2
170+
cache.u_prev = @. cache.u
145171

146172
return nothing
147173
end
148174

149175
function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
150-
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
176+
abstol = cache.abstol, reltol = cache.reltol,
177+
termination_condition = cache.termination_condition,
178+
maxiters = cache.maxiters) where {iip}
151179
cache.p = p
152180
if iip
153181
recursivecopy!(cache.u, u0)
@@ -157,7 +185,14 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca
157185
cache.u = u0
158186
cache.fu = cache.f(cache.u, p)
159187
end
188+
termination_condition = _get_reinit_termination_condition(cache,
189+
abstol,
190+
reltol,
191+
termination_condition)
192+
160193
cache.abstol = abstol
194+
cache.reltol = reltol
195+
cache.termination_condition = termination_condition
161196
cache.maxiters = maxiters
162197
cache.stats.nf = 1
163198
cache.stats.nsteps = 1

src/dfsane.jl

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,16 @@ end
8888
internalnorm
8989
retcode::SciMLBase.ReturnCode.T
9090
abstol
91+
reltol
9192
prob
9293
stats::NLStats
94+
termination_condition
95+
tc_storage
9396
end
9497

9598
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;
96-
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
99+
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
100+
termination_condition = nothing, internalnorm = DEFAULT_NORM,
97101
kwargs...) where {uType, iip}
98102
uₙ = alias_u0 ? prob.u0 : deepcopy(prob.u0)
99103

@@ -122,14 +126,27 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.
122126
f₍ₙₒᵣₘ₎₀ = f₍ₙₒᵣₘ₎ₙ₋₁
123127

124128
ℋ = fill(f₍ₙₒᵣₘ₎ₙ₋₁, M)
129+
130+
abstol, reltol, termination_condition = _init_termination_elements(abstol,
131+
reltol,
132+
termination_condition,
133+
T)
134+
135+
mode = DiffEqBase.get_termination_mode(termination_condition)
136+
137+
storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
138+
nothing
139+
125140
return DFSaneCache{iip}(alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀,
126141
M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, p, false, maxiters,
127-
internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0))
142+
internalnorm, ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0),
143+
termination_condition, storage)
128144
end
129145

130146
function perform_step!(cache::DFSaneCache{true})
131-
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
147+
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache
132148

149+
termination_condition = cache.termination_condition(tc_storage)
133150
f = (dx, x) -> cache.prob.f(dx, x, cache.p)
134151

135152
T = eltype(cache.uₙ)
@@ -174,7 +191,7 @@ function perform_step!(cache::DFSaneCache{true})
174191
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
175192
end
176193

177-
if cache.internalnorm(cache.fuₙ) < cache.abstol
194+
if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
178195
cache.force_stop = true
179196
end
180197

@@ -205,8 +222,9 @@ function perform_step!(cache::DFSaneCache{true})
205222
end
206223

207224
function perform_step!(cache::DFSaneCache{false})
208-
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
225+
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache
209226

227+
termination_condition = cache.termination_condition(tc_storage)
210228
f = x -> cache.prob.f(x, cache.p)
211229

212230
T = eltype(cache.uₙ)
@@ -249,7 +267,7 @@ function perform_step!(cache::DFSaneCache{false})
249267
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
250268
end
251269

252-
if cache.internalnorm(cache.fuₙ) < cache.abstol
270+
if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
253271
cache.force_stop = true
254272
end
255273

@@ -296,7 +314,9 @@ function SciMLBase.solve!(cache::DFSaneCache)
296314
end
297315

298316
function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p,
299-
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
317+
abstol = cache.abstol, reltol = cache.reltol,
318+
termination_condition = cache.termination_condition,
319+
maxiters = cache.maxiters) where {iip}
300320
cache.p = p
301321
if iip
302322
recursivecopy!(cache.uₙ, u0)
@@ -317,7 +337,14 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p
317337
T = eltype(cache.uₙ)
318338
cache.σₙ = T(cache.alg.σ_1)
319339

340+
termination_condition = _get_reinit_termination_condition(cache,
341+
abstol,
342+
reltol,
343+
termination_condition)
344+
320345
cache.abstol = abstol
346+
cache.reltol = reltol
347+
cache.termination_condition = termination_condition
321348
cache.maxiters = maxiters
322349
cache.stats.nf = 1
323350
cache.stats.nsteps = 1

src/gaussnewton.jl

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,16 @@ end
4949
function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
5050
precs = DEFAULT_PRECS, adkwargs...)
5151
ad = default_adargs_to_adtype(; adkwargs...)
52-
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
52+
return GaussNewton{_unwrap_val(concrete_jac)}(ad,
53+
linsolve,
54+
precs)
5355
end
5456

5557
@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
5658
f
5759
alg
5860
u
61+
u_prev
5962
fu1
6063
fu2
6164
fu_new
@@ -72,12 +75,17 @@ end
7275
internalnorm
7376
retcode::ReturnCode.T
7477
abstol
78+
reltol
7579
prob
7680
stats::NLStats
81+
tc_storage
82+
termination_condition
7783
end
7884

7985
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
80-
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
86+
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
87+
termination_condition = nothing,
88+
internalnorm = DEFAULT_NORM,
8189
kwargs...) where {uType, iip}
8290
alg = get_concrete_algorithm(alg_, prob)
8391
@unpack f, u0, p = prob
@@ -101,15 +109,29 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
101109
JᵀJ, Jᵀf = nothing, nothing
102110
end
103111

104-
return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
112+
abstol, reltol, termination_condition = _init_termination_elements(abstol,
113+
reltol,
114+
termination_condition,
115+
eltype(u); mode = NLSolveTerminationMode.AbsNorm)
116+
117+
mode = DiffEqBase.get_termination_mode(termination_condition)
118+
119+
storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
120+
nothing
121+
122+
return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
123+
linsolve, J,
105124
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
106-
prob, NLStats(1, 0, 0, 0, 0))
125+
reltol,
126+
prob, NLStats(1, 0, 0, 0, 0), storage, termination_condition)
107127
end
108128

109129
function perform_step!(cache::GaussNewtonCache{true})
110-
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
130+
@unpack u, u_prev, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du, tc_storage = cache
111131
jacobian!!(J, cache)
112132

133+
termination_condition = cache.termination_condition(tc_storage)
134+
113135
if JᵀJ !== nothing
114136
__matmul!(JᵀJ, J', J)
115137
__matmul!(Jᵀf, J', fu1)
@@ -127,9 +149,15 @@ function perform_step!(cache::GaussNewtonCache{true})
127149
@. u = u - du
128150
f(cache.fu_new, u, p)
129151

130-
(cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
131-
cache.internalnorm(cache.fu_new) < cache.abstol) &&
152+
(termination_condition(cache.fu_new .- cache.fu1,
153+
cache.u,
154+
u_prev,
155+
cache.abstol,
156+
cache.reltol) ||
157+
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol)) &&
132158
(cache.force_stop = true)
159+
160+
@. u_prev = u
133161
cache.fu1 .= cache.fu_new
134162
cache.stats.nf += 1
135163
cache.stats.njacs += 1
@@ -139,7 +167,9 @@ function perform_step!(cache::GaussNewtonCache{true})
139167
end
140168

141169
function perform_step!(cache::GaussNewtonCache{false})
142-
@unpack u, fu1, f, p, alg, linsolve = cache
170+
@unpack u, u_prev, fu1, f, p, alg, linsolve, tc_storage = cache
171+
172+
termination_condition = cache.termination_condition(tc_storage)
143173

144174
cache.J = jacobian!!(cache.J, cache)
145175

@@ -164,7 +194,10 @@ function perform_step!(cache::GaussNewtonCache{false})
164194
cache.u = @. u - cache.du # `u` might not support mutation
165195
cache.fu_new = f(cache.u, p)
166196

167-
(cache.internalnorm(cache.fu_new) < cache.abstol) && (cache.force_stop = true)
197+
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol) &&
198+
(cache.force_stop = true)
199+
200+
cache.u_prev = @. cache.u
168201
cache.fu1 = cache.fu_new
169202
cache.stats.nf += 1
170203
cache.stats.njacs += 1
@@ -174,7 +207,9 @@ function perform_step!(cache::GaussNewtonCache{false})
174207
end
175208

176209
function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache.p,
177-
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
210+
abstol = cache.abstol, reltol = cache.reltol,
211+
termination_condition = cache.termination_condition,
212+
maxiters = cache.maxiters) where {iip}
178213
cache.p = p
179214
if iip
180215
recursivecopy!(cache.u, u0)
@@ -184,7 +219,14 @@ function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache
184219
cache.u = u0
185220
cache.fu1 = cache.f(cache.u, p)
186221
end
222+
termination_condition = _get_reinit_termination_condition(cache,
223+
abstol,
224+
reltol,
225+
termination_condition)
226+
187227
cache.abstol = abstol
228+
cache.reltol = reltol
229+
cache.termination_condition = termination_condition
188230
cache.maxiters = maxiters
189231
cache.stats.nf = 1
190232
cache.stats.nsteps = 1

0 commit comments

Comments
 (0)