@@ -145,7 +145,7 @@ function (::TRBDF2)(res, uₙ, Δt, f!, du, u, p, t, stages, stage)
145145 # after Bonaventura2021
146146 # They define the second stage as:
147147 # u - γ₂ * Δt * f(u, t+Δt) = (1-γ₃)uₙ + γ₃u₁
148- # Which differs from Bank1985
148+ # Which differs from Bank1985)
149149 # (2-γ)u + (1-γ)Δt * f(u, t+Δt) = 1/γ * u₁ - 1/γ * (1-γ)^2 * uₙ
150150 # In the sign of u - γ₂ * Δt
151151 # a₁ == (1-γ₃)
@@ -158,6 +158,38 @@ function (::TRBDF2)(res, uₙ, Δt, f!, du, u, p, t, stages, stage)
158158 end
159159end
160160
161+ abstract type DIRK{N} <: SimpleImplicitAlgorithm{N} end
162+
163+ struct RKImplicitEuler <: DIRK{1} end
164+
165+ function (:: RKImplicitEuler )(res, uₙ, Δt, f!, du, u, p, t, stages, stage, RK)
166+ if stage == 1
167+ # Stage 1:
168+ f! (du, u, p, t + RK. c[stage] * Δt)
169+ return res .= u .- uₙ .- RK. a[stage, stage] * Δt .* du
170+ else
171+ @. u = uₙ + RK. b[1 ] * Δt * stages[1 ]
172+ end
173+
174+ end
175+
176+ struct KS2 <: DIRK{2} end
177+ struct QZ2 <: DIRK{2} end
178+ struct Crouzeix <: DIRK{2} end
179+
180+ function (:: DIRK{2} )(res, uₙ, Δt, f!, du, u, p, t, stages, stage, RK)
181+ if stage == 1
182+ f! (du, u, p, t + RK. c[stage] * Δt)
183+ return res .= u .- uₙ .- RK. a[stage, stage] * Δt .* du
184+ elseif stage == 2
185+ f! (du, u, p, t + RK. c[stage] * Δt)
186+ return res .= u .- uₙ .- RK. a[stage, 1 ] * Δt .* stages[1 ] - RK. a[stage,2 ] * Δt .* du
187+ else
188+ @. u = uₙ + Δt * (RK. b[1 ] * stages[1 ] + RK. b[2 ] * stages[2 ])
189+ end
190+
191+ end
192+
161193struct Rosenbrock <: Direct{3} end
162194
163195function (:: Rosenbrock )(res, uₙ, Δt, f!, du, u, p, t, stages, stage, workspace, M, RK)
@@ -191,6 +223,12 @@ struct RosenbrockButcher{T1 <: AbstractArray, T2 <: AbstractArray} <: RKTableau
191223 m:: T2
192224end
193225
226+ struct DIRKButcher{T1 <: AbstractArray , T2 <: AbstractArray } <: RKTableau
227+ a:: T1
228+ b:: T2
229+ c:: T2
230+ end
231+
194232function RosenbrockTableau ()
195233
196234 # SSP - Knoth
@@ -219,6 +257,72 @@ function RosenbrockTableau()
219257
220258end
221259
260+ function ImplicitEulerTableau ()
261+
262+ nstage = 1
263+ a = zeros (Float64, nstage, nstage)
264+ a[1 ,1 ] = 1
265+
266+ b = zeros (Float64, nstage)
267+ b[1 ] = 1
268+
269+ c = zeros (Float64, nstage)
270+ c[1 ] = 1
271+ return DIRKButcher (a,b,c)
272+ end
273+
274+ # Kraaijevanger and Spijker's two-stage Diagonally Implicit Runge–Kutta method:
275+ function KS2Tableau ()
276+ nstage = 2
277+ a = zeros (Float64, nstage, nstage)
278+ a[1 ,1 ] = 1 / 2
279+ a[2 ,1 ] = - 1 / 2
280+ a[2 ,2 ] = 2
281+ b = zeros (Float64, nstage)
282+ b[1 ] = - 1 / 2
283+ b[2 ] = 3 / 2
284+
285+ c = zeros (Float64, nstage)
286+ c[1 ] = 1 / 2
287+ c[2 ] = 3 / 2
288+ return DIRKButcher (a,b,c)
289+
290+ end
291+ # Qin and Zhang's two-stage, 2nd order, symplectic Diagonally Implicit Runge–Kutta method:
292+ function QZ2Tableau ()
293+ nstage = 2
294+ a = zeros (Float64, nstage, nstage)
295+ a[1 ,1 ] = 1 / 4
296+ a[2 ,1 ] = 1 / 2
297+ a[2 ,2 ] = 1 / 4
298+ b = zeros (Float64, nstage)
299+ b[1 ] = 1 / 2
300+ b[2 ] = 1 / 2
301+
302+ c = zeros (Float64, nstage)
303+ c[1 ] = 1 / 4
304+ c[2 ] = 3 / 4
305+ return DIRKButcher (a,b,c)
306+ end
307+
308+ # Crouzeix's two-stage, 3rd order Diagonally Implicit Runge–Kutta method
309+ function CrouzeixTableau ()
310+ nstage = 2
311+ a = zeros (Float64, nstage, nstage)
312+ a[1 ,1 ] = 1 / 2 + sqrt (3 )/ 6
313+ a[2 ,1 ] = - sqrt (3 )/ 3
314+ a[2 ,2 ] = 1 / 2 + sqrt (3 )/ 6
315+ b = zeros (Float64, nstage)
316+ b[1 ] = 1 / 2
317+ b[2 ] = 1 / 2
318+
319+ c = zeros (Float64, nstage)
320+ c[1 ] = 1 / 2 + sqrt (3 )/ 6
321+ c[2 ] = 1 / 2 - sqrt (3 )/ 6
322+ return DIRKButcher (a,b,c)
323+ end
324+
325+
222326function RKTableau (alg:: Direct )
223327 return RosenbrockTableau ()
224328end
@@ -227,11 +331,30 @@ function RKTableau(alg::NonDirect)
227331 return RosenbrockTableau ()
228332end
229333
334+ function RKTableau (alg:: RKImplicitEuler )
335+ return ImplicitEulerTableau ()
336+ end
337+
338+ function RKTableau (alg:: KS2 )
339+ return KS2Tableau ()
340+ end
341+
342+ function RKTableau (alg:: QZ2 )
343+ return QZ2Tableau ()
344+ end
345+
346+ function RKTableau (alg:: Crouzeix )
347+ return CrouzeixTableau ()
348+ end
230349
231350function nonlinear_problem (alg:: SimpleImplicitAlgorithm , f:: F ) where {F}
232351 return (res, u, (uₙ, Δt, du, p, t, stages, stage)) -> alg (res, uₙ, Δt, f, du, u, p, t, stages, stage)
233352end
234353
354+ function nonlinear_problem (alg:: DIRK , f:: F ) where {F}
355+ return (res, u, (uₙ, Δt, du, p, t, stages, stage, RK)) -> alg (res, uₙ, Δt, f, du, u, p, t, stages, stage, RK)
356+ end
357+
235358# This struct is needed to fake https://github.com/SciML/OrdinaryDiffEq.jl/blob/0c2048a502101647ac35faabd80da8a5645beac7/src/integrators/type.jl#L1
236359mutable struct SimpleImplicitOptions{Callback}
237360 callback:: Callback # callbacks; used in Trixi.jl
@@ -368,6 +491,27 @@ function stage!(integrator, alg::NonDirect)
368491 end
369492end
370493
494+ function stage! (integrator, alg:: DIRK )
495+ for stage in 1 : stages (alg)
496+ F! = nonlinear_problem (alg, integrator. f)
497+ # TODO : Pass in `stages[1:(stage-1)]` or full tuple?
498+ _, stats = Ariadne. newton_krylov! (
499+ F!, integrator. u_tmp, (integrator. u, integrator. dt, integrator. du, integrator. p, integrator. t, integrator. stages, stage, integrator. RK), integrator. res;
500+ verbose = integrator. opts. verbose, krylov_kwargs = integrator. opts. krylov_kwargs,
501+ algo = integrator. opts. algo, tol_abs = 6.0e-6 ,
502+ )
503+ @assert stats. solved
504+
505+ # Store the solution for each stage in stages
506+ integrator. f (integrator. du, integrator. u_tmp, integrator. p, integrator. t + integrator. RK. c[stage] * integrator. dt)
507+ integrator. stages[stage] .= integrator. du
508+ if stage == stages (alg)
509+ alg (integrator. res, integrator. u, integrator. dt, integrator. f, integrator. du, integrator. u_tmp, integrator. p, integrator. t, integrator. stages, stage+ 1 , integrator. RK)
510+ end
511+
512+ end
513+ end
514+
371515function stage! (integrator, alg:: Direct )
372516
373517 F! (du, u, p) = integrator. f (du, u, p, integrator. t)
0 commit comments