Skip to content

Commit 0256766

Browse files
Merge pull request #251 from yonatanwesen/yd/pseudotransient
Pseudo-Transient Method
2 parents 7fadae1 + 011a815 commit 0256766

File tree

4 files changed

+309
-1
lines changed

4 files changed

+309
-1
lines changed

src/NonlinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ include("trustRegion.jl")
6767
include("levenberg.jl")
6868
include("gaussnewton.jl")
6969
include("dfsane.jl")
70+
include("pseudotransient.jl")
7071
include("jacobian.jl")
7172
include("ad.jl")
7273
include("default.jl")
@@ -95,7 +96,7 @@ end
9596

9697
export RadiusUpdateSchemes
9798

98-
export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton
99+
export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient
99100
export LeastSquaresOptimJL, FastLevenbergMarquardtJL
100101
export RobustMultiNewton, FastShortcutNonlinearPolyalg
101102

src/jacobian.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
9898
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
9999
needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); u0 = _vec(du))
100100

101+
if alg isa PseudoTransient
102+
alpha = convert(eltype(u), alg.alpha_initial)
103+
J_new = J - (1 / alpha) * I
104+
linprob = LinearProblem(J_new, _vec(fu); u0 = _vec(du))
105+
end
106+
101107
weight = similar(u)
102108
recursivefill!(weight, true)
103109

src/pseudotransient.jl

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""
2+
PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
3+
precs = DEFAULT_PRECS, alpha_initial = 1e-3, adkwargs...)
4+
5+
An implementation of the pseudo-transient method, which is used to solve steady-state problems in an accelerated manner. It uses adaptive time-stepping to
6+
integrate the initial value problem associated with the nonlinear problem until sufficient accuracy in the desired steady state is achieved to switch over to Newton's method and
7+
gain rapid convergence. This implementation specifically uses the "switched evolution relaxation" (SER) method. For detailed information about the time-stepping and algorithm,
8+
please see the paper: [Coffey, Todd S. and Kelley, C. T. and Keyes, David E. (2003), Pseudotransient Continuation and Differential-Algebraic Equations,
9+
SIAM Journal on Scientific Computing,25, 553-569.](https://doi.org/10.1137/S106482750241044X)
10+
11+
### Keyword Arguments
12+
13+
- `alpha_initial`: the initial pseudo-time step. It defaults to `1e-3`. If it is small, more iterations are needed to converge.
14+
15+
16+
17+
18+
"""
19+
@concrete struct PseudoTransient{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
20+
ad::AD # AD backend selection; the keyword constructor builds it with `default_adargs_to_adtype`
21+
linsolve # user-supplied linear solver choice (`nothing` by default in the keyword constructor)
22+
precs # preconditioner specification; forwarded to `dolinsolve` in `perform_step!`
23+
alpha_initial # initial pseudo-time step; converted to `eltype(u)` when the cache is built
24+
end
25+
26+
#concrete_jac(::PseudoTransient{CJ}) where {CJ} = CJ
27+
# Return a copy of `alg` that uses the AD backend `ad`; the concrete-Jacobian type
# parameter `CJ` and all remaining settings are carried over unchanged.
function set_ad(alg::PseudoTransient{CJ}, ad) where {CJ}
    @unpack linsolve, precs, alpha_initial = alg
    return PseudoTransient{CJ}(ad, linsolve, precs, alpha_initial)
end
30+
31+
# Keyword constructor: any keyword not named here is collected in `adkwargs` and turned
# into an AD type, and `concrete_jac` is unwrapped into the first type parameter.
function PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
        precs = DEFAULT_PRECS, alpha_initial = 1e-3, adkwargs...)
    adtype = default_adargs_to_adtype(; adkwargs...)
    CJ = _unwrap_val(concrete_jac)
    return PseudoTransient{CJ}(adtype, linsolve, precs, alpha_initial)
end
36+
37+
@concrete mutable struct PseudoTransientCache{iip}
38+
f # residual function taken from the problem
39+
alg # concrete algorithm returned by `get_concrete_algorithm`
40+
u # current iterate
41+
fu1 # residual evaluated at `u`
42+
fu2 # work array returned by `jacobian_caches`
43+
du # step computed by the linear solve; the update is `u - du`
44+
p # problem parameters
45+
alpha # current pseudo-time step; multiplied by `res_norm / new_norm` after every step
46+
res_norm # `internalnorm` of the most recently accepted residual
47+
uf # returned by `jacobian_caches`
48+
linsolve # linear-solver cache; replaced by `linres.cache` after each solve
49+
J # Jacobian storage, refreshed by `jacobian!!` at the start of each step
50+
jac_cache # returned by `jacobian_caches`
51+
force_stop # set to `true` once the residual norm drops below `abstol`
52+
maxiters::Int # upper bound on `stats.nsteps` in `solve!`
53+
internalnorm # norm function used for the residual
54+
retcode::ReturnCode.T # `ReturnCode.Default` until `solve!` finishes
55+
abstol # residual-norm tolerance; also passed as `reltol` to `dolinsolve`
56+
prob # the original problem, needed by `build_solution`
57+
stats::NLStats # counters updated by `perform_step!` and `solve!`
58+
end
59+
60+
# Report the `iip` type parameter, i.e. whether the cache was built for an in-place problem.
function isinplace(::PseudoTransientCache{iip}) where {iip}
    return iip
end
61+
62+
# Build the iteration cache for a `PseudoTransient` solve.
#
# Keyword arguments:
# - `alias_u0`: when `true`, iterate directly on `prob.u0` instead of a deep copy.
# - `maxiters`: maximum number of steps `solve!` will take.
# - `abstol`: residual-norm tolerance used for the stopping test.
# - `internalnorm`: norm applied to the residual.
# - `linsolve_kwargs`: forwarded to `jacobian_caches`.
# Remaining `args...` / `kwargs...` are accepted and not used here.
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::PseudoTransient,
        args...;
        alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
        linsolve_kwargs = (;),
        kwargs...) where {uType, iip}
    alg = get_concrete_algorithm(alg_, prob)

    f, p = prob.f, prob.p
    u = alias_u0 ? prob.u0 : deepcopy(prob.u0)

    # Initial residual: written into a buffer for in-place problems, returned (and made
    # mutable) for out-of-place ones.
    fu1 = if iip
        resid = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
        f(resid, u, p)
        resid
    else
        _mutable(f(u, p))
    end

    uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
        linsolve_kwargs)

    alpha = convert(eltype(u), alg.alpha_initial)
    res_norm = internalnorm(fu1)
    stats = NLStats(1, 0, 0, 0, 0)

    return PseudoTransientCache{iip}(f, alg, u, fu1, fu2, du, p, alpha, res_norm, uf,
        linsolve, J, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
        abstol, prob, stats)
end
91+
92+
# One in-place pseudo-transient step: refresh the Jacobian, solve with the shifted matrix
# `J - (1/alpha) * I`, update `u` and the residual in place, then rescale `alpha` by the
# ratio of the previous to the new residual norm.
function perform_step!(cache::PseudoTransientCache{true})
    # Snapshot everything up front, before the Jacobian refresh touches the cache.
    state, resid, step = cache.u, cache.fu1, cache.du
    f, p, alg = cache.f, cache.p, cache.alg
    jac, linsolve, pseudo_dt = cache.J, cache.linsolve, cache.alpha

    jacobian!!(jac, cache)
    shifted_jac = jac - (1 / pseudo_dt) * I

    # state = state - shifted_jac \ resid
    solved = dolinsolve(alg.precs, linsolve; A = shifted_jac, b = _vec(resid),
        linu = _vec(step), p, reltol = cache.abstol)
    cache.linsolve = solved.cache
    state .= state .- step
    f(resid, state, p)

    resid_norm = cache.internalnorm(resid)
    cache.alpha *= cache.res_norm / resid_norm
    cache.res_norm = resid_norm
    if resid_norm < cache.abstol
        cache.force_stop = true
    end

    stats = cache.stats
    stats.nf += 1
    stats.njacs += 1
    stats.nsolve += 1
    stats.nfactors += 1
    return nothing
end
115+
116+
# One out-of-place pseudo-transient step.
#
# Refreshes the Jacobian, solves with the shifted matrix `J - (1/alpha) * I`, stores the
# new iterate and residual on the cache (nothing is mutated in place, since `u` might be
# immutable), then rescales `alpha` by the ratio of the previous to the new residual norm
# and raises `force_stop` once the new residual norm is below `abstol`.
function perform_step!(cache::PseudoTransientCache{false})
    @unpack u, fu1, f, p, alg, linsolve, alpha = cache

    cache.J = jacobian!!(cache.J, cache)
    # Shifted Jacobian, built once and shared by both solve paths.
    J_new = cache.J - (1 / alpha) * I

    # u = u - J_new \ fu
    if linsolve === nothing
        cache.du = fu1 / J_new
    else
        linres = dolinsolve(alg.precs, linsolve; A = J_new, b = _vec(fu1),
            linu = _vec(cache.du), p, reltol = cache.abstol)
        cache.linsolve = linres.cache
    end
    cache.u = @. u - cache.du # `u` might not support mutation
    cache.fu1 = f(cache.u, p)

    # The local `fu1` still refers to the residual from *before* this step; the norm must
    # be taken of the freshly evaluated `cache.fu1`, otherwise the alpha update and the
    # convergence test lag one iteration behind.
    new_norm = cache.internalnorm(cache.fu1)
    cache.alpha *= cache.res_norm / new_norm
    cache.res_norm = new_norm
    new_norm < cache.abstol && (cache.force_stop = true)
    cache.stats.nf += 1
    cache.stats.njacs += 1
    cache.stats.nsolve += 1
    cache.stats.nfactors += 1
    return nothing
end
142+
143+
# Run `perform_step!` until the cache signals convergence or `maxiters` steps have been
# taken, then package the result.
#
# The return code is `ReturnCode.Success` exactly when `force_stop` was raised (which
# `perform_step!` does when the residual norm drops below `abstol`), and
# `ReturnCode.MaxIters` otherwise.
function SciMLBase.solve!(cache::PseudoTransientCache)
    while !cache.force_stop && cache.stats.nsteps < cache.maxiters
        perform_step!(cache)
        cache.stats.nsteps += 1
    end

    # Decide on the convergence flag rather than on `nsteps == maxiters`: converging on
    # the very last allowed step is still a success, and leaving the loop without
    # convergence is never one.
    cache.retcode = cache.force_stop ? ReturnCode.Success : ReturnCode.MaxIters

    return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
        cache.retcode, cache.stats)
end
158+
159+
# Reset `cache` so it can be solved again, optionally from a new starting point.
#
# Arguments:
# - `u0`: new initial guess (defaults to the current iterate `cache.u`).
# Keyword arguments:
# - `p`: new parameters (defaults to the current ones).
# - `alpha_new`: pseudo-time step to restart from. Defaults to the algorithm's
#   `alpha_initial`, so callers that do not care can omit it.
# - `abstol`, `maxiters`: new stopping settings (default to the current ones).
#
# Returns the same `cache`, mutated.
function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = cache.p,
        alpha_new = cache.alg.alpha_initial,
        abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
    cache.p = p
    if iip
        recursivecopy!(cache.u, u0)
        cache.f(cache.fu1, cache.u, p)
    else
        # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
        cache.u = u0
        cache.fu1 = cache.f(cache.u, p)
    end
    cache.alpha = convert(eltype(cache.u), alpha_new)
    cache.res_norm = cache.internalnorm(cache.fu1)
    cache.abstol = abstol
    cache.maxiters = maxiters
    # One residual evaluation was just performed above.
    cache.stats.nf = 1
    # No step has been taken yet; `solve!` compares this counter against `maxiters`,
    # so starting from 1 would silently cost one iteration per re-solve.
    cache.stats.nsteps = 0
    cache.force_stop = false
    cache.retcode = ReturnCode.Default
    return cache
end

test/basictests.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,3 +543,124 @@ end
543543
end
544544
end
545545
end
546+
547+
# --- PseudoTransient tests ---
548+
549+
@testset "PseudoTransient" begin
    # These tests mirror the NewtonRaphson ones, so `alpha_initial` is set high to make
    # the method converge quickly.

    # Solve an out-of-place problem with PseudoTransient at a tight tolerance.
    function benchmark_nlsolve_oop(f, u0, p = 2.0; alpha_initial = 10.0)
        prob = NonlinearProblem{false}(f, u0, p)
        return solve(prob, PseudoTransient(; alpha_initial), abstol = 1e-9)
    end

    # Solve an in-place problem with an explicit linear solver and preconditioner.
    function benchmark_nlsolve_iip(f, u0, p = 2.0; linsolve, precs,
            alpha_initial = 10.0)
        prob = NonlinearProblem{true}(f, u0, p)
        return solve(prob, PseudoTransient(; linsolve, precs, alpha_initial), abstol = 1e-9)
    end

    @testset "PT: alpha_initial = 10.0 PT AD: $(ad)" for ad in (AutoFiniteDiff(),
        AutoZygote())
        u0s = VERSION ≥ v"1.9" ? ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) : ([1.0, 1.0], 1.0)

        @testset "[OOP] u0: $(typeof(u0))" for u0 in u0s
            sol = benchmark_nlsolve_oop(quadratic_f, u0)
            @test SciMLBase.successful_retcode(sol)
            @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

            cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0),
                PseudoTransient(alpha_initial = 10.0),
                abstol = 1e-9)
            @test (@ballocated solve!($cache)) < 200
        end

        precs = [NonlinearSolve.DEFAULT_PRECS, :Random]

        @testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([
                1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES())
            ad isa AutoZygote && continue
            if prec === :Random
                prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing)
            end
            sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec)
            @test SciMLBase.successful_retcode(sol)
            @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

            cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0),
                PseudoTransient(; alpha_initial = 10.0, linsolve, precs = prec),
                abstol = 1e-9)
            @test (@ballocated solve!($cache)) ≤ 64
        end
    end

    if VERSION ≥ v"1.9"
        @testset "[OOP] [Immutable AD]" begin
            for p in 1.0:0.1:100.0
                @test begin
                    res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
                    res_true = sqrt(p)
                    all(res.u .≈ res_true)
                end
                @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
                        @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p))
            end
        end
    end

    @testset "[OOP] [Scalar AD]" begin
        for p in 1.0:0.1:100.0
            @test begin
                res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
                res_true = sqrt(p)
                res.u ≈ res_true
            end
            @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
                p) ≈
                  1 / (2 * sqrt(p))
        end
    end

    if VERSION ≥ v"1.9"
        t = (p) -> [sqrt(p[2] / p[1])]
        p = [0.9, 50.0]
        @test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1])
        @test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
            p) ≈
              ForwardDiff.jacobian(t, p)
    end

    # Exercise the init / reinit! / solve! iterator interface, warm-starting each solve
    # from the previous solution while sweeping the parameter.
    function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip}
        probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin])
        cache = init(probN,
            PseudoTransient(alpha_initial = 10.0);
            maxiters = 100,
            abstol = 1e-10)
        sols = zeros(length(p_range))
        for (i, p) in enumerate(p_range)
            reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p, alpha_new = 10.0)
            sol = solve!(cache)
            sols[i] = iip ? sol.u[1] : sol.u
        end
        return sols
    end
    p = range(0.01, 2, length = 200)
    @test nlprob_iterator_interface(quadratic_f, p, Val(false)) ≈ sqrt.(p)
    @test nlprob_iterator_interface(quadratic_f!, p, Val(true)) ≈ sqrt.(p)

    @testset "ADType: $(autodiff) u0: $(_nameof(u0))" for autodiff in (false, true,
            AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
            AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0])
        probN = NonlinearProblem(quadratic_f, u0, 2.0)
        @test all(solve(probN, PseudoTransient(; alpha_initial = 10.0, autodiff)).u .≈
                  sqrt(2.0))
    end

    @testset "NewtonRaphson Fails but PT passes" begin # Test that `PseudoTransient` passes a test that `NewtonRaphson` fails on.
        p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
        probN = NonlinearProblem{false}(newton_fails, u0, p)
        sol = solve(probN, PseudoTransient(alpha_initial = 1.0), abstol = 1e-10)
        @test all(abs.(newton_fails(sol.u, p)) .< 1e-10)
    end
end

0 commit comments

Comments
 (0)