Skip to content

Commit f934680

Browse files
committed
Start using termination conditions in newton raphson
1 parent 7fadae1 commit f934680

File tree

2 files changed

+42
-10
lines changed

2 files changed

+42
-10
lines changed

src/raphson.jl

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,32 @@ for large-scale and numerically-difficult nonlinear systems.
3030
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
3131
used here directly, and they will be converted to the correct `LineSearch`.
3232
"""
33-
@concrete struct NewtonRaphson{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
33+
@concrete struct NewtonRaphson{CJ, AD, TC <: NLSolveTerminationCondition} <: AbstractNewtonAlgorithm{CJ, AD}
3434
ad::AD
3535
linsolve
3636
precs
3737
linesearch
38+
termination_condition::TC
3839
end
3940

4041
function set_ad(alg::NewtonRaphson{CJ}, ad) where {CJ}
4142
return NewtonRaphson{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
4243
end
4344

4445
function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
45-
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
46+
linesearch = LineSearch(), precs = DEFAULT_PRECS, termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
47+
abstol = nothing,
48+
reltol = nothing), adkwargs...)
4649
ad = default_adargs_to_adtype(; adkwargs...)
4750
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
48-
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
51+
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch, termination_condition)
4952
end
5053

5154
@concrete mutable struct NewtonRaphsonCache{iip} <: AbstractNonlinearSolveCache{iip}
5255
f
5356
alg
5457
u
58+
uprev
5559
fu1
5660
fu2
5761
du
@@ -65,9 +69,11 @@ end
6569
internalnorm
6670
retcode::ReturnCode.T
6771
abstol
72+
reltol
6873
prob
6974
stats::NLStats
7075
lscache
76+
tc_storage
7177
end
7278

7379
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphson, args...;
@@ -80,15 +86,28 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
8086
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
8187
linsolve_kwargs)
8288

83-
return NewtonRaphsonCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
84-
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob,
85-
NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)))
89+
90+
tc = alg.termination_condition
91+
mode = DiffEqBase.get_termination_mode(tc)
92+
93+
atol = _get_tolerance(abstol, tc.abstol, eltype(u))
94+
rtol = _get_tolerance(reltol, tc.reltol, eltype(u))
95+
96+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
97+
nothing
98+
99+
return NewtonRaphsonCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve, J,
100+
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob,
101+
NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)), storage)
86102
end
87103

88104
function perform_step!(cache::NewtonRaphsonCache{true})
89-
@unpack u, fu1, f, p, alg, J, linsolve, du = cache
105+
@unpack u, uprev, fu1, f, p, alg, J, linsolve, du = cache
90106
jacobian!!(J, cache)
91107

108+
tc_storage = cache.tc_storage
109+
termination_condition = cache.alg.termination_condition(tc_storage)
110+
92111
# u = u - J \ fu
93112
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
94113
p, reltol = cache.abstol)
@@ -99,7 +118,9 @@ function perform_step!(cache::NewtonRaphsonCache{true})
99118
@. u = u - α * du
100119
f(cache.fu1, u, p)
101120

102-
cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true)
121+
termination_condition(cache.fu1, u, uprev, cache.abstol, cache.reltol) && (cache.force_stop = true)
122+
123+
@. uprev = u
103124
cache.stats.nf += 1
104125
cache.stats.njacs += 1
105126
cache.stats.nsolve += 1
@@ -108,7 +129,11 @@ function perform_step!(cache::NewtonRaphsonCache{true})
108129
end
109130

110131
function perform_step!(cache::NewtonRaphsonCache{false})
111-
@unpack u, fu1, f, p, alg, linsolve = cache
132+
@unpack u, uprev, fu1, f, p, alg, linsolve = cache
133+
134+
tc_storage = cache.tc_storage
135+
termination_condition = cache.alg.termination_condition(tc_storage)
136+
112137

113138
cache.J = jacobian!!(cache.J, cache)
114139
# u = u - J \ fu
@@ -125,7 +150,9 @@ function perform_step!(cache::NewtonRaphsonCache{false})
125150
cache.u = @. u - α * cache.du # `u` might not support mutation
126151
cache.fu1 = f(cache.u, p)
127152

128-
cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true)
153+
termination_condition(cache.fu1, cache.u, uprev, cache.abstol, cache.reltol) && (cache.force_stop = true)
154+
155+
cache.uprev = cache.u
129156
cache.stats.nf += 1
130157
cache.stats.njacs += 1
131158
cache.stats.nsolve += 1

src/utils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,8 @@ function __get_concrete_algorithm(alg, prob)
208208
end
209209
return set_ad(alg, ad)
210210
end
211+
212+
function _get_tolerance(η, tc_η, ::Type{T}) where {T}
213+
@show fallback_η
214+
return ifelse!== nothing, η, ifelse(tc_η !== nothing, tc_η, fallback_η))
215+
end

0 commit comments

Comments
 (0)