Skip to content

Commit 0eb715d

Browse files
committed
forgot to add the pt file
1 parent d92f2d7 commit 0eb715d

File tree

1 file changed

+180
-0
lines changed

1 file changed

+180
-0
lines changed

src/pseudotransient.jl

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""
2+
PseudoTransient{nothing, AutoForwardDiff{0, Bool}, Nothing, typeof(NonlinearSolve.DEFAULT_PRECS), Float64}(AutoForwardDiff{0, Bool}(true),
3+
nothing, NonlinearSolve.DEFAULT_PRECS, 0.001)
4+
5+
An implementation of PseudoTransient method that is used to solve steady state problems in an accelerated manner. It uses an adaptive time-stepping to
6+
integrate an initial value of nonlinear problem until sufficient accuracy in the desired steady-state is achieved to switch over to Newton's method and
7+
gain a rapid convergence. This implementation specifically uses "switched evolution relaxation" SER method. For detail information about the time-stepping and algorithm,
8+
please see the paper: [Coffey, Todd S. and Kelley, C. T. and Keyes, David E. (2003), Pseudotransient Continuation and Differential-Algebraic Equations,
9+
SIAM Journal on Scientific Computing,25, 553-569.](https://doi.org/10.1137/S106482750241044X)
10+
11+
### Keyword Arguments
12+
13+
- `alpha_initial` : the initial pseudo time step. it defaults to 1e-3. If it is small, you are going to need more iterations to converge.
14+
15+
16+
17+
18+
"""
19+
@concrete struct PseudoTransient{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
20+
ad::AD
21+
linsolve
22+
precs
23+
alpha_initial
24+
end
25+
26+
#concrete_jac(::PseudoTransient{CJ}) where {CJ} = CJ
27+
function set_ad(alg::PseudoTransient{CJ}, ad) where {CJ}
28+
return PseudoTransient{CJ}(ad, alg.linsolve, alg.precs, alg.alpha_initial)
29+
end
30+
31+
function PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
32+
precs = DEFAULT_PRECS, alpha_initial = 1e-3, adkwargs...)
33+
ad = default_adargs_to_adtype(; adkwargs...)
34+
return PseudoTransient{_unwrap_val(concrete_jac)}(ad, linsolve, precs, alpha_initial)
35+
end
36+
37+
@concrete mutable struct PseudoTransientCache{iip}
38+
f
39+
alg
40+
u
41+
fu1
42+
fu2
43+
du
44+
p
45+
alpha
46+
res_norm
47+
uf
48+
linsolve
49+
J
50+
jac_cache
51+
force_stop
52+
maxiters::Int
53+
internalnorm
54+
retcode::ReturnCode.T
55+
abstol
56+
prob
57+
stats::NLStats
58+
end
59+
60+
isinplace(::PseudoTransientCache{iip}) where {iip} = iip
61+
62+
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::PseudoTransient,
63+
args...;
64+
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
65+
linsolve_kwargs = (;),
66+
kwargs...) where {uType, iip}
67+
alg = get_concrete_algorithm(alg_, prob)
68+
69+
@unpack f, u0, p = prob
70+
u = alias_u0 ? u0 : deepcopy(u0)
71+
if iip
72+
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
73+
f(fu1, u, p)
74+
else
75+
fu1 = _mutable(f(u, p))
76+
end
77+
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg,
78+
f,
79+
u,
80+
p,
81+
Val(iip);
82+
linsolve_kwargs)
83+
alpha = convert(eltype(u), alg.alpha_initial)
84+
res_norm = internalnorm(fu1)
85+
86+
return PseudoTransientCache{iip}(f, alg, u, fu1, fu2, du, p, alpha, res_norm, uf,
87+
linsolve, J,
88+
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob,
89+
NLStats(1, 0, 0, 0, 0))
90+
end
91+
92+
function perform_step!(cache::PseudoTransientCache{true})
93+
@unpack u, fu1, f, p, alg, J, linsolve, du, alpha = cache
94+
jacobian!!(J, cache)
95+
J_new = J - (1 / alpha) * I
96+
97+
# u = u - J \ fu
98+
linres = dolinsolve(alg.precs, linsolve; A = J_new, b = _vec(fu1), linu = _vec(du),
99+
p, reltol = cache.abstol)
100+
cache.linsolve = linres.cache
101+
@. u = u - du
102+
f(fu1, u, p)
103+
104+
new_norm = cache.internalnorm(fu1)
105+
cache.alpha *= cache.res_norm / new_norm
106+
cache.res_norm = new_norm
107+
108+
new_norm < cache.abstol && (cache.force_stop = true)
109+
cache.stats.nf += 1
110+
cache.stats.njacs += 1
111+
cache.stats.nsolve += 1
112+
cache.stats.nfactors += 1
113+
return nothing
114+
end
115+
116+
function perform_step!(cache::PseudoTransientCache{false})
117+
@unpack u, fu1, f, p, alg, linsolve, alpha = cache
118+
119+
cache.J = jacobian!!(cache.J, cache)
120+
# u = u - J \ fu
121+
if linsolve === nothing
122+
cache.du = fu1 / (cache.J - (1 / alpha) * I)
123+
else
124+
linres = dolinsolve(alg.precs, linsolve; A = cache.J - (1 / alpha) * I,
125+
b = _vec(fu1),
126+
linu = _vec(cache.du), p, reltol = cache.abstol)
127+
cache.linsolve = linres.cache
128+
end
129+
cache.u = @. u - cache.du # `u` might not support mutation
130+
cache.fu1 = f(cache.u, p)
131+
132+
new_norm = cache.internalnorm(fu1)
133+
cache.alpha *= cache.res_norm / new_norm
134+
cache.res_norm = new_norm
135+
new_norm < cache.abstol && (cache.force_stop = true)
136+
cache.stats.nf += 1
137+
cache.stats.njacs += 1
138+
cache.stats.nsolve += 1
139+
cache.stats.nfactors += 1
140+
return nothing
141+
end
142+
143+
function SciMLBase.solve!(cache::PseudoTransientCache)
144+
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
145+
perform_step!(cache)
146+
cache.stats.nsteps += 1
147+
end
148+
149+
if cache.stats.nsteps == cache.maxiters
150+
cache.retcode = ReturnCode.MaxIters
151+
else
152+
cache.retcode = ReturnCode.Success
153+
end
154+
155+
return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
156+
cache.retcode, cache.stats)
157+
end
158+
159+
function SciMLBase.reinit!(cache::PseudoTransientCache{iip}, u0 = cache.u; p = cache.p,
160+
alpha_new,
161+
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
162+
cache.p = p
163+
if iip
164+
recursivecopy!(cache.u, u0)
165+
cache.f(cache.fu1, cache.u, p)
166+
else
167+
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
168+
cache.u = u0
169+
cache.fu1 = cache.f(cache.u, p)
170+
end
171+
cache.alpha = convert(eltype(cache.u), alpha_new)
172+
cache.res_norm = cache.internalnorm(cache.fu1)
173+
cache.abstol = abstol
174+
cache.maxiters = maxiters
175+
cache.stats.nf = 1
176+
cache.stats.nsteps = 1
177+
cache.force_stop = false
178+
cache.retcode = ReturnCode.Default
179+
return cache
180+
end

0 commit comments

Comments
 (0)