@@ -18,29 +18,33 @@ The available implementations are:
1818"""
1919abstract type AdditiveRungeKutta <: DistributedODEAlgorithm end
2020
21- struct AdditiveRungeKuttaTableau{Nstages, Nstages², RT}
21+ const T1TypeARK = NTuple{Nstages, RT} where {Nstages, RT}
22+ const T2TypeARK = SArray{NTuple{2 , Nstages}, RT, 2 , Nstages²} where {Nstages, RT, Nstages²}
23+
24+ struct AdditiveRungeKuttaTableau{T2 <: T2TypeARK , T1 <: T1TypeARK }
2225 " RK coefficient vector A (rhs scaling) for the explicit part"
23- Aexpl:: SArray{NTuple{2, Nstages}, RT, 2, Nstages²}
26+ Aexpl:: T2
2427 " RK coefficient vector A (rhs scaling) for the implicit part"
25- Aimpl:: SArray{NTuple{2, Nstages}, RT, 2, Nstages²}
28+ Aimpl:: T2
2629 " RK coefficient vector B (rhs add in scaling)"
27- B:: NTuple{Nstages, RT}
30+ B:: T1
2831 " RK coefficient vector C (time scaling)"
29- C:: NTuple{Nstages, RT}
32+ C:: T1
3033end
34+ n_stages (:: AdditiveRungeKuttaTableau{T2, T1} ) where {T2, T1} = n_stages_ntuple (T1)
3135
32- struct AdditiveRungeKuttaFullCache{Nstages, RT , A, O, L}
36+ struct AdditiveRungeKuttaFullCache{T1, T , A, O, L}
3337 " stage value of the state variable"
3438 U:: A # Qstages
3539 " evaluated linear part of each stage ``f_L(U^{(i)})``"
36- L:: NTuple{Nstages, A} # Lstages
40+ L:: T1 # Lstages
3741 " evaluated remainder part of each stage ``f_R(U^{(i)})``"
38- R:: NTuple{Nstages, A} # Rstages
39- tableau:: AdditiveRungeKuttaTableau{Nstages, RT}
42+ R:: T1 # Rstages
43+ tableau:: T
4044 W:: O
4145 linsolve!:: L
4246end
43-
47+ n_stages (cache :: AdditiveRungeKuttaFullCache ) = n_stages (cache . tableau)
4448
4549function init_cache (
4650 prob:: DiffEqBase.AbstractODEProblem{uType, tType, true} ,
@@ -56,11 +60,12 @@ function init_cache(
5660 L = ntuple (i -> zero (prob. u0), Nstages)
5761 R = ntuple (i -> zero (prob. u0), Nstages)
5862
59- if prob. f isa DiffEqBase. ODEFunction
60- W = EulerOperator ( prob. f. jvp, - dt * Aimpl[ 2 , 2 ], prob . p, prob . tspan[ 1 ])
63+ f = if prob. f isa DiffEqBase. ODEFunction
64+ prob. f. jvp
6165 elseif prob. f isa DiffEqBase. SplitFunction
62- W = EulerOperator ( prob. f. f1, - dt * Aimpl[ 2 , 2 ], prob . p, prob . tspan[ 1 ])
66+ prob. f. f1
6367 end
68+ W = EulerOperator (f, - dt * Aimpl[2 , 2 ], prob. p, prob. tspan[1 ])
6469 linsolve! = alg. linsolve (Val{:init }, W, prob. u0; kwargs... )
6570
6671 AdditiveRungeKuttaFullCache (U, L, R, tab, W, linsolve!)
@@ -104,12 +109,13 @@ solve(prob, Rosenbrock23(linsolve=ColumnGMRES))
104109# W = M - gamma*J <=> our EulerOperator
105110# https://github.com/SciML/OrdinaryDiffEq.jl/blob/f93630317658b0c5460044a5d349f99391bc2f9c/src/derivative_utils.jl#L126
106111
107- function step_u! (int, cache:: AdditiveRungeKuttaFullCache{Nstages} ) where {Nstages}
112+ function step_u! (int, cache:: AdditiveRungeKuttaFullCache )
108113 step_u! (int, cache, int. sol. prob. f)
109114end
110115
111- function step_u! (int, cache:: AdditiveRungeKuttaFullCache{Nstages} , f:: DiffEqBase.SplitFunction ) where {Nstages}
116+ function step_u! (int, cache:: AdditiveRungeKuttaFullCache , f:: DiffEqBase.SplitFunction )
112117
118+ Nstages = n_stages (cache)
113119 (; C, Aimpl, Aexpl, B) = cache. tableau
114120 (; U, R, L, W, linsolve) = cache
115121 (; u, p, t, dt) = int
0 commit comments