Skip to content

Commit 416e656

Browse files
committed
Add termination condition to Pseudotransient methods
1 parent 814f704 commit 416e656

File tree

2 files changed

+64
-11
lines changed

2 files changed

+64
-11
lines changed

src/pseudotransient.jl

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
precs = DEFAULT_PRECS, alpha_initial = 1e-3, adkwargs...)
44
55
An implementation of PseudoTransient method that is used to solve steady state problems in an accelerated manner. It uses an adaptive time-stepping to
6-
integrate an initial value of nonlinear problem until sufficient accuracy in the desired steady-state is achieved to switch over to Newton's method and
6+
integrate an initial value of nonlinear problem until sufficient accuracy in the desired steady-state is achieved to switch over to Newton's method and
77
gain a rapid convergence. This implementation specifically uses "switched evolution relaxation" SER method. For detail information about the time-stepping and algorithm,
8-
please see the paper: [Coffey, Todd S. and Kelley, C. T. and Keyes, David E. (2003), Pseudotransient Continuation and Differential-Algebraic Equations,
8+
please see the paper: [Coffey, Todd S. and Kelley, C. T. and Keyes, David E. (2003), Pseudotransient Continuation and Differential-Algebraic Equations,
99
SIAM Journal on Scientific Computing,25, 553-569.](https://doi.org/10.1137/S106482750241044X)
1010
1111
### Keyword Arguments
@@ -27,7 +27,7 @@ SIAM Journal on Scientific Computing,25, 553-569.](https://doi.org/10.1137/S1064
2727
preconditioners. For more information on specifying preconditioners for LinearSolve
2828
algorithms, consult the
2929
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
30-
- `alpha_initial` : the initial pseudo time step. it defaults to 1e-3. If it is small,
30+
- `alpha_initial` : the initial pseudo time step. it defaults to 1e-3. If it is small,
3131
you are going to need more iterations to converge but it can be more stable.
3232
"""
3333
@concrete struct PseudoTransient{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
@@ -52,6 +52,7 @@ end
5252
f
5353
alg
5454
u
55+
u_prev
5556
fu1
5657
fu2
5758
du
@@ -67,15 +68,19 @@ end
6768
internalnorm
6869
retcode::ReturnCode.T
6970
abstol
71+
reltol
7072
prob
7173
stats::NLStats
74+
termination_condition
75+
tc_storage
7276
end
7377

7478
isinplace(::PseudoTransientCache{iip}) where {iip} = iip
7579

7680
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::PseudoTransient,
7781
args...;
78-
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
82+
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
83+
termination_condition = nothing, internalnorm = DEFAULT_NORM,
7984
linsolve_kwargs = (;),
8085
kwargs...) where {uType, iip}
8186
alg = get_concrete_algorithm(alg_, prob)
@@ -93,16 +98,30 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::PseudoTransi
9398
alpha = convert(eltype(u), alg.alpha_initial)
9499
res_norm = internalnorm(fu1)
95100

96-
return PseudoTransientCache{iip}(f, alg, u, fu1, fu2, du, p, alpha, res_norm, uf,
101+
abstol, reltol, termination_condition = _init_termination_elements(abstol,
102+
reltol,
103+
termination_condition,
104+
eltype(u))
105+
106+
mode = DiffEqBase.get_termination_mode(termination_condition)
107+
108+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
109+
nothing
110+
111+
return PseudoTransientCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, alpha, res_norm,
112+
uf,
97113
linsolve, J, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
98-
prob, NLStats(1, 0, 0, 0, 0))
114+
reltol,
115+
prob, NLStats(1, 0, 0, 0, 0), termination_condition, storage)
99116
end
100117

101118
function perform_step!(cache::PseudoTransientCache{true})
102-
@unpack u, fu1, f, p, alg, J, linsolve, du, alpha = cache
119+
@unpack u, u_prev, fu1, f, p, alg, J, linsolve, du, alpha, tc_storage = cache
103120
jacobian!!(J, cache)
104121
J_new = J - (1 / alpha) * I
105122

123+
termination_condition = cache.termination_condition(tc_storage)
124+
106125
# u = u - J \ fu
107126
linres = dolinsolve(alg.precs, linsolve; A = J_new, b = _vec(fu1), linu = _vec(du),
108127
p, reltol = cache.abstol)
@@ -114,7 +133,10 @@ function perform_step!(cache::PseudoTransientCache{true})
114133
cache.alpha *= cache.res_norm / new_norm
115134
cache.res_norm = new_norm
116135

117-
new_norm < cache.abstol && (cache.force_stop = true)
136+
termination_condition(fu1, u, u_prev, cache.abstol, cache.reltol) &&
137+
(cache.force_stop = true)
138+
139+
@. u_prev = u
118140
cache.stats.nf += 1
119141
cache.stats.njacs += 1
120142
cache.stats.nsolve += 1
@@ -123,7 +145,10 @@ function perform_step!(cache::PseudoTransientCache{true})
123145
end
124146

125147
function perform_step!(cache::PseudoTransientCache{false})
126-
@unpack u, fu1, f, p, alg, linsolve, alpha = cache
148+
@unpack u, u_prev, fu1, f, p, alg, linsolve, alpha, tc_storage = cache
149+
150+
tc_storage = cache.tc_storage
151+
termination_condition = cache.termination_condition(tc_storage)
127152

128153
cache.J = jacobian!!(cache.J, cache)
129154
# u = u - J \ fu
@@ -141,7 +166,9 @@ function perform_step!(cache::PseudoTransientCache{false})
141166
new_norm = cache.internalnorm(fu1)
142167
cache.alpha *= cache.res_norm / new_norm
143168
cache.res_norm = new_norm
144-
new_norm < cache.abstol && (cache.force_stop = true)
169+
termination_condition(fu1, cache.u, u_prev, cache.abstol, cache.reltol) &&
170+
(cache.force_stop = true)
171+
cache.u_prev = @. cache.u
145172
cache.stats.nf += 1
146173
cache.stats.njacs += 1
147174
cache.stats.nsolve += 1
@@ -167,7 +194,9 @@ end
167194

168195
function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = cache.p,
169196
alpha_new,
170-
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
197+
abstol = cache.abstol, reltol = cache.reltol,
198+
termination_condition = cache.termination_condition,
199+
maxiters = cache.maxiters) where {iip}
171200
cache.p = p
172201
if iip
173202
recursivecopy!(cache.u, u0)
@@ -177,9 +206,17 @@ function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = c
177206
cache.u = u0
178207
cache.fu1 = cache.f(cache.u, p)
179208
end
209+
210+
termination_condition = _get_reinit_termination_condition(cache,
211+
abstol,
212+
reltol,
213+
termination_condition)
214+
180215
cache.alpha = convert(eltype(cache.u), alpha_new)
181216
cache.res_norm = cache.internalnorm(cache.fu1)
182217
cache.abstol = abstol
218+
cache.reltol = reltol
219+
cache.termination_condition = termination_condition
183220
cache.maxiters = maxiters
184221
cache.stats.nf = 1
185222
cache.stats.nsteps = 1

test/basictests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,22 @@ end
720720
sol = solve(probN, PseudoTransient(alpha_initial = 1.0), abstol = 1e-10)
721721
@test all(abs.(newton_fails(sol.u, p)) .< 1e-10)
722722
end
723+
724+
@testset "Termination condition: $(mode) u0: $(_nameof(u0))" for mode in instances(NLSolveTerminationMode.T),
725+
u0 in (1.0, [1.0, 1.0])
726+
727+
if mode
728+
(NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest,
729+
NLSolveTerminationMode.AbsSafeBest)
730+
continue
731+
end
732+
termination_condition = NLSolveTerminationCondition(mode; abstol = nothing,
733+
reltol = nothing)
734+
probN = NonlinearProblem(quadratic_f, u0, 2.0)
735+
@test all(solve(probN,
736+
PseudoTransient(; alpha_initial = 10.0);
737+
termination_condition).u .≈ sqrt(2.0))
738+
end
723739
end
724740

725741
# --- GeneralBroyden tests ---

0 commit comments

Comments
 (0)