Skip to content

Commit 6d0935b

Browse files
Fix types and bugs in ark
1 parent dd9574a commit 6d0935b

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

src/solvers/ark.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,33 @@ The available implementations are:
1818
"""
1919
abstract type AdditiveRungeKutta <: DistributedODEAlgorithm end
2020

21-
struct AdditiveRungeKuttaTableau{Nstages, Nstages², RT}
21+
const T1TypeARK = NTuple{Nstages, RT} where {Nstages, RT}
22+
const T2TypeARK = SArray{NTuple{2, Nstages}, RT, 2, Nstages²} where {Nstages, RT, Nstages²}
23+
24+
struct AdditiveRungeKuttaTableau{T2 <: T2TypeARK, T1 <: T1TypeARK}
2225
"RK coefficient vector A (rhs scaling) for the explicit part"
23-
Aexpl::SArray{NTuple{2, Nstages}, RT, 2, Nstages²}
26+
Aexpl::T2
2427
"RK coefficient vector A (rhs scaling) for the implicit part"
25-
Aimpl::SArray{NTuple{2, Nstages}, RT, 2, Nstages²}
28+
Aimpl::T2
2629
"RK coefficient vector B (rhs add in scaling)"
27-
B::NTuple{Nstages, RT}
30+
B::T1
2831
"RK coefficient vector C (time scaling)"
29-
C::NTuple{Nstages, RT}
32+
C::T1
3033
end
34+
n_stages(::AdditiveRungeKuttaTableau{T2, T1}) where {T2, T1} = n_stages_ntuple(T1)
3135

32-
struct AdditiveRungeKuttaFullCache{Nstages, RT, A, O, L}
36+
struct AdditiveRungeKuttaFullCache{T1, T, A, O, L}
3337
"stage value of the state variable"
3438
U::A #Qstages
3539
"evaluated linear part of each stage ``f_L(U^{(i)})``"
36-
L::NTuple{Nstages, A} #Lstages
40+
L::T1 #Lstages
3741
"evaluated remainder part of each stage ``f_R(U^{(i)})``"
38-
R::NTuple{Nstages, A} #Rstages
39-
tableau::AdditiveRungeKuttaTableau{Nstages, RT}
42+
R::T1 #Rstages
43+
tableau::T
4044
W::O
4145
linsolve!::L
4246
end
43-
47+
n_stages(cache::AdditiveRungeKuttaFullCache) = n_stages(cache.tableau)
4448

4549
function init_cache(
4650
prob::DiffEqBase.AbstractODEProblem{uType, tType, true},
@@ -105,12 +109,13 @@ solve(prob, Rosenbrock23(linsolve=ColumnGMRES))
105109
# W = M - gamma*J <=> our EulerOperator
106110
# https://github.com/SciML/OrdinaryDiffEq.jl/blob/f93630317658b0c5460044a5d349f99391bc2f9c/src/derivative_utils.jl#L126
107111

108-
function step_u!(int, cache::AdditiveRungeKuttaFullCache{Nstages}) where {Nstages}
112+
function step_u!(int, cache::AdditiveRungeKuttaFullCache)
109113
step_u!(int, cache, int.sol.prob.f)
110114
end
111115

112-
function step_u!(int, cache::AdditiveRungeKuttaFullCache{Nstages}, f::DiffEqBase.SplitFunction) where {Nstages}
116+
function step_u!(int, cache::AdditiveRungeKuttaFullCache, f::DiffEqBase.SplitFunction)
113117

118+
Nstages = n_stages(cache)
114119
(; C, Aimpl, Aexpl, B) = cache.tableau
115120
(; U, R, L, W, linsolve) = cache
116121
(; u, p, t, dt) = int

0 commit comments

Comments
 (0)