Skip to content

Commit 6e81d6f

Browse files
committed
add DIRK
1 parent bea7d6a commit 6e81d6f

File tree

2 files changed

+181
-1
lines changed

2 files changed

+181
-1
lines changed

examples/trixi_dirk.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# # Using the an implicit solver based on Ariadne with Trixi.jl
2+
3+
using Trixi
4+
using Implicit
5+
using CairoMakie
6+
7+
8+
# Notes:
9+
# Must disable both Polyester and LoopVectorization for Enzyme to be able to differentiate Trixi.jl
10+
# Using https://github.com/trixi-framework/Trixi.jl/pull/2295
11+
#
12+
# LocalPreferences.jl
13+
# ```toml
14+
# [Trixi]
15+
# loop_vectorization = false
16+
# polyester = false
17+
# ```
18+
19+
@assert !Trixi._PREFERENCE_POLYESTER
20+
@assert !Trixi._PREFERENCE_LOOPVECTORIZATION
21+
22+
trixi_include(joinpath(examples_dir(), "tree_2d_dgsem", "elixir_advection_basic.jl"), cfl = 10.0, sol = nothing, save_solution = nothing);
23+
24+
###############################################################################
25+
# run the simulation
26+
27+
sol = solve(
28+
ode,
29+
# Implicit.RKImplicitEuler();
30+
# Implicit.KS2();
31+
Implicit.Crouzeix();
32+
dt = 1.0, # solve needs some value here but it will be overwritten by the stepsize_callback
33+
ode_default_options()..., callback = callbacks,
34+
# verbose=1,
35+
krylov_algo = :gmres,
36+
);

libs/Implicit/src/Implicit.jl

Lines changed: 145 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ function (::TRBDF2)(res, uₙ, Δt, f!, du, u, p, t, stages, stage)
145145
# after Bonaventura2021
146146
# They define the second stage as:
147147
# u - γ₂ * Δt * f(u, t+Δt) = (1-γ₃)uₙ + γ₃u₁
148-
# Which differs from Bank1985
148+
# Which differs from Bank1985)
149149
# (2-γ)u + (1-γ)Δt * f(u, t+Δt) = 1/γ * u₁ - 1/γ * (1-γ)^2 * uₙ
150150
# In the sign of u - γ₂ * Δt
151151
# a₁ == (1-γ₃)
@@ -158,6 +158,38 @@ function (::TRBDF2)(res, uₙ, Δt, f!, du, u, p, t, stages, stage)
158158
end
159159
end
160160

161+
abstract type DIRK{N} <: SimpleImplicitAlgorithm{N} end
162+
163+
struct RKImplicitEuler <: DIRK{1} end
164+
165+
function (::RKImplicitEuler)(res, uₙ, Δt, f!, du, u, p, t, stages, stage, RK)
166+
if stage == 1
167+
# Stage 1:
168+
f!(du, u, p, t + RK.c[stage] * Δt)
169+
return res .= u .- uₙ .- RK.a[stage, stage] * Δt .* du
170+
else
171+
@. u = uₙ + RK.b[1] * Δt * stages[1]
172+
end
173+
174+
end
175+
176+
struct KS2 <: DIRK{2} end
177+
struct QZ2 <: DIRK{2} end
178+
struct Crouzeix <: DIRK{2} end
179+
180+
function (::DIRK{2})(res, uₙ, Δt, f!, du, u, p, t, stages, stage, RK)
181+
if stage == 1
182+
f!(du, u, p, t + RK.c[stage] * Δt)
183+
return res .= u .- uₙ .- RK.a[stage, stage] * Δt .* du
184+
elseif stage == 2
185+
f!(du, u, p, t + RK.c[stage] * Δt)
186+
return res .= u .- uₙ .- RK.a[stage, 1] * Δt .* stages[1] - RK.a[stage,2] * Δt .* du
187+
else
188+
@. u = uₙ + Δt * (RK.b[1] * stages[1] + RK.b[2] * stages[2])
189+
end
190+
191+
end
192+
161193
struct Rosenbrock <: Direct{3} end
162194

163195
function (::Rosenbrock)(res, uₙ, Δt, f!, du, u, p, t, stages, stage, workspace, M, RK)
@@ -191,6 +223,12 @@ struct RosenbrockButcher{T1 <: AbstractArray, T2 <: AbstractArray} <: RKTableau
191223
m::T2
192224
end
193225

226+
struct DIRKButcher{T1 <: AbstractArray, T2 <: AbstractArray} <: RKTableau
227+
a::T1
228+
b::T2
229+
c::T2
230+
end
231+
194232
function RosenbrockTableau()
195233

196234
# SSP - Knoth
@@ -219,6 +257,72 @@ function RosenbrockTableau()
219257

220258
end
221259

260+
function ImplicitEulerTableau()
261+
262+
nstage = 1
263+
a = zeros(Float64, nstage, nstage)
264+
a[1,1] = 1
265+
266+
b = zeros(Float64, nstage)
267+
b[1] = 1
268+
269+
c = zeros(Float64, nstage)
270+
c[1] = 1
271+
return DIRKButcher(a,b,c)
272+
end
273+
274+
# Kraaijevanger and Spijker's two-stage Diagonally Implicit Runge–Kutta method:
275+
function KS2Tableau()
276+
nstage = 2
277+
a = zeros(Float64, nstage, nstage)
278+
a[1,1] = 1/2
279+
a[2,1] = -1/2
280+
a[2,2] = 2
281+
b = zeros(Float64, nstage)
282+
b[1] = -1/2
283+
b[2] = 3/2
284+
285+
c = zeros(Float64, nstage)
286+
c[1] = 1/2
287+
c[2] = 3/2
288+
return DIRKButcher(a,b,c)
289+
290+
end
291+
# Qin and Zhang's two-stage, 2nd order, symplectic Diagonally Implicit Runge–Kutta method:
292+
function QZ2Tableau()
293+
nstage = 2
294+
a = zeros(Float64, nstage, nstage)
295+
a[1,1] = 1/4
296+
a[2,1] = 1/2
297+
a[2,2] = 1/4
298+
b = zeros(Float64, nstage)
299+
b[1] = 1/2
300+
b[2] = 1/2
301+
302+
c = zeros(Float64, nstage)
303+
c[1] = 1/4
304+
c[2] = 3/4
305+
return DIRKButcher(a,b,c)
306+
end
307+
308+
# Crouzeix's two-stage, 3rd order Diagonally Implicit Runge–Kutta method
309+
function CrouzeixTableau()
310+
nstage = 2
311+
a = zeros(Float64, nstage, nstage)
312+
a[1,1] = 1/2 + sqrt(3)/6
313+
a[2,1] = -sqrt(3)/3
314+
a[2,2] = 1/2 + sqrt(3)/6
315+
b = zeros(Float64, nstage)
316+
b[1] = 1/2
317+
b[2] = 1/2
318+
319+
c = zeros(Float64, nstage)
320+
c[1] = 1/2 + sqrt(3)/6
321+
c[2] = 1/2 - sqrt(3)/6
322+
return DIRKButcher(a,b,c)
323+
end
324+
325+
222326
function RKTableau(alg::Direct)
223327
return RosenbrockTableau()
224328
end
@@ -227,11 +331,30 @@ function RKTableau(alg::NonDirect)
227331
return RosenbrockTableau()
228332
end
229333

334+
function RKTableau(alg::RKImplicitEuler)
335+
return ImplicitEulerTableau()
336+
end
337+
338+
function RKTableau(alg::KS2)
339+
return KS2Tableau()
340+
end
341+
342+
function RKTableau(alg::QZ2)
343+
return QZ2Tableau()
344+
end
345+
346+
function RKTableau(alg::Crouzeix)
347+
return CrouzeixTableau()
348+
end
230349

231350
function nonlinear_problem(alg::SimpleImplicitAlgorithm, f::F) where {F}
232351
return (res, u, (uₙ, Δt, du, p, t, stages, stage)) -> alg(res, uₙ, Δt, f, du, u, p, t, stages, stage)
233352
end
234353

354+
function nonlinear_problem(alg::DIRK, f::F) where {F}
355+
return (res, u, (uₙ, Δt, du, p, t, stages, stage, RK)) -> alg(res, uₙ, Δt, f, du, u, p, t, stages, stage, RK)
356+
end
357+
235358
# This struct is needed to fake https://github.com/SciML/OrdinaryDiffEq.jl/blob/0c2048a502101647ac35faabd80da8a5645beac7/src/integrators/type.jl#L1
236359
mutable struct SimpleImplicitOptions{Callback}
237360
callback::Callback # callbacks; used in Trixi.jl
@@ -368,6 +491,27 @@ function stage!(integrator, alg::NonDirect)
368491
end
369492
end
370493

494+
function stage!(integrator, alg::DIRK)
495+
for stage in 1:stages(alg)
496+
F! = nonlinear_problem(alg, integrator.f)
497+
# TODO: Pass in `stages[1:(stage-1)]` or full tuple?
498+
_, stats = Ariadne.newton_krylov!(
499+
F!, integrator.u_tmp, (integrator.u, integrator.dt, integrator.du, integrator.p, integrator.t, integrator.stages, stage, integrator.RK), integrator.res;
500+
verbose = integrator.opts.verbose, krylov_kwargs = integrator.opts.krylov_kwargs,
501+
algo = integrator.opts.algo, tol_abs = 6.0e-6,
502+
)
503+
@assert stats.solved
504+
505+
# Store the solution for each stage in stages
506+
integrator.f(integrator.du, integrator.u_tmp, integrator.p, integrator.t + integrator.RK.c[stage] * integrator.dt)
507+
integrator.stages[stage] .= integrator.du
508+
if stage == stages(alg)
509+
alg(integrator.res, integrator.u, integrator.dt, integrator.f, integrator.du, integrator.u_tmp, integrator.p, integrator.t, integrator.stages, stage+1, integrator.RK)
510+
end
511+
512+
end
513+
end
514+
371515
function stage!(integrator, alg::Direct)
372516

373517
F!(du, u, p) = integrator.f(du, u, p, integrator.t)

0 commit comments

Comments
 (0)