Skip to content

Commit 58bfc6a

Browse files
Merge #140
140: Refactor and house cleaning r=charleskawczynski a=charleskawczynski This PR fixes a few type issues Co-authored-by: Charles Kawczynski <[email protected]>
2 parents 9f7d443 + 6d0935b commit 58bfc6a

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

src/solvers/ark.jl

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,33 @@ The available implementations are:
1818
"""
1919
abstract type AdditiveRungeKutta <: DistributedODEAlgorithm end
2020

21-
struct AdditiveRungeKuttaTableau{Nstages, Nstages², RT}
21+
const T1TypeARK = NTuple{Nstages, RT} where {Nstages, RT}
22+
const T2TypeARK = SArray{NTuple{2, Nstages}, RT, 2, Nstages²} where {Nstages, RT, Nstages²}
23+
24+
struct AdditiveRungeKuttaTableau{T2 <: T2TypeARK, T1 <: T1TypeARK}
2225
"RK coefficient vector A (rhs scaling) for the explicit part"
23-
Aexpl::SArray{NTuple{2, Nstages}, RT, 2, Nstages²}
26+
Aexpl::T2
2427
"RK coefficient vector A (rhs scaling) for the implicit part"
25-
Aimpl::SArray{NTuple{2, Nstages}, RT, 2, Nstages²}
28+
Aimpl::T2
2629
"RK coefficient vector B (rhs add in scaling)"
27-
B::NTuple{Nstages, RT}
30+
B::T1
2831
"RK coefficient vector C (time scaling)"
29-
C::NTuple{Nstages, RT}
32+
C::T1
3033
end
34+
n_stages(::AdditiveRungeKuttaTableau{T2, T1}) where {T2, T1} = n_stages_ntuple(T1)
3135

32-
struct AdditiveRungeKuttaFullCache{Nstages, RT, A, O, L}
36+
struct AdditiveRungeKuttaFullCache{T1, T, A, O, L}
3337
"stage value of the state variable"
3438
U::A #Qstages
3539
"evaluated linear part of each stage ``f_L(U^{(i)})``"
36-
L::NTuple{Nstages, A} #Lstages
40+
L::T1 #Lstages
3741
"evaluated remainder part of each stage ``f_R(U^{(i)})``"
38-
R::NTuple{Nstages, A} #Rstages
39-
tableau::AdditiveRungeKuttaTableau{Nstages, RT}
42+
R::T1 #Rstages
43+
tableau::T
4044
W::O
4145
linsolve!::L
4246
end
43-
47+
n_stages(cache::AdditiveRungeKuttaFullCache) = n_stages(cache.tableau)
4448

4549
function init_cache(
4650
prob::DiffEqBase.AbstractODEProblem{uType, tType, true},
@@ -56,11 +60,12 @@ function init_cache(
5660
L = ntuple(i -> zero(prob.u0), Nstages)
5761
R = ntuple(i -> zero(prob.u0), Nstages)
5862

59-
if prob.f isa DiffEqBase.ODEFunction
60-
W = EulerOperator(prob.f.jvp, -dt * Aimpl[2, 2], prob.p, prob.tspan[1])
63+
f = if prob.f isa DiffEqBase.ODEFunction
64+
prob.f.jvp
6165
elseif prob.f isa DiffEqBase.SplitFunction
62-
W = EulerOperator(prob.f.f1, -dt * Aimpl[2, 2], prob.p, prob.tspan[1])
66+
prob.f.f1
6367
end
68+
W = EulerOperator(f, -dt * Aimpl[2, 2], prob.p, prob.tspan[1])
6469
linsolve! = alg.linsolve(Val{:init}, W, prob.u0; kwargs...)
6570

6671
AdditiveRungeKuttaFullCache(U, L, R, tab, W, linsolve!)
@@ -104,12 +109,13 @@ solve(prob, Rosenbrock23(linsolve=ColumnGMRES))
104109
# W = M - gamma*J <=> our EulerOperator
105110
# https://github.com/SciML/OrdinaryDiffEq.jl/blob/f93630317658b0c5460044a5d349f99391bc2f9c/src/derivative_utils.jl#L126
106111

107-
function step_u!(int, cache::AdditiveRungeKuttaFullCache{Nstages}) where {Nstages}
112+
function step_u!(int, cache::AdditiveRungeKuttaFullCache)
108113
step_u!(int, cache, int.sol.prob.f)
109114
end
110115

111-
function step_u!(int, cache::AdditiveRungeKuttaFullCache{Nstages}, f::DiffEqBase.SplitFunction) where {Nstages}
116+
function step_u!(int, cache::AdditiveRungeKuttaFullCache, f::DiffEqBase.SplitFunction)
112117

118+
Nstages = n_stages(cache)
113119
(; C, Aimpl, Aexpl, B) = cache.tableau
114120
(; U, R, L, W, linsolve) = cache
115121
(; u, p, t, dt) = int

0 commit comments

Comments
 (0)