Skip to content

Commit 1a1df0d

Browse files
Merge #133
133: Refactor, unpack efficiently, unify n_stages nstages r=charleskawczynski a=charleskawczynski This PR applies a few refactoring changes: - Changes syntax `a = x.a` to `(; a) = x` for several variables - Adds a `n_stages` method for some caches - Renames `nstages` to `n_stages` Co-authored-by: Charles Kawczynski <[email protected]>
2 parents 792875f + d830bf3 commit 1a1df0d

File tree

7 files changed

+109
-127
lines changed

7 files changed

+109
-127
lines changed

src/solvers/ark.jl

Lines changed: 49 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,16 @@ function init_cache(
5050
) where {uType, tType}
5151

5252
tab = tableau(alg, eltype(prob.u0))
53-
Nstages = length(tab.B) # TODO: create function for this
53+
(; B, Aimpl) = tab
54+
Nstages = length(B) # TODO: create function for this
5455
U = zero(prob.u0)
5556
L = ntuple(i -> zero(prob.u0), Nstages)
5657
R = ntuple(i -> zero(prob.u0), Nstages)
5758

5859
if prob.f isa DiffEqBase.ODEFunction
59-
W = EulerOperator(prob.f.jvp, -dt * tab.Aimpl[2, 2], prob.p, prob.tspan[1])
60+
W = EulerOperator(prob.f.jvp, -dt * Aimpl[2, 2], prob.p, prob.tspan[1])
6061
elseif prob.f isa DiffEqBase.SplitFunction
61-
W = EulerOperator(prob.f.f1, -dt * tab.Aimpl[2, 2], prob.p, prob.tspan[1])
62+
W = EulerOperator(prob.f.f1, -dt * Aimpl[2, 2], prob.p, prob.tspan[1])
6263
end
6364
linsolve! = alg.linsolve(Val{:init}, W, prob.u0; kwargs...)
6465

@@ -108,95 +109,88 @@ function step_u!(int, cache::AdditiveRungeKuttaFullCache{Nstages}) where {Nstage
108109
end
109110

110111
function step_u!(int, cache::AdditiveRungeKuttaFullCache{Nstages}, f::DiffEqBase.SplitFunction) where {Nstages}
111-
tab = cache.tableau
112-
U = cache.U
113-
Uhat = cache.R[end] # can be used as work array, as only used in last stage
114112

113+
(; C, Aimpl, Aexpl, B) = cache.tableau
114+
(; U, R, L, W, linsolve) = cache
115+
(; u, p, t, dt) = int
115116

116-
u = int.u
117-
p = int.p
118-
t = int.t
119-
dt = int.dt
117+
U = U
118+
Uhat = R[end] # can be used as work array, as only used in last stage
120119

121120
fL! = f.f1 # linear part
122121
fR! = f.f2 # remainder
123122

124123
# first stage is always explicit
125-
τ = t + tab.C[1] * dt
124+
τ = t + C[1] * dt
126125

127-
# cache.L[i] .= fL(cache.U[i-1], p, t + tab.C[i]*dt)
128-
# cache.R[1] .= fR(cache.U[i-1], p, t + tab.C[i]*dt)
126+
# L[i] .= fL(U[i-1], p, t + C[i]*dt)
127+
# R[1] .= fR(U[i-1], p, t + C[i]*dt)
129128

130-
fL!(cache.L[1], u, p, τ)
131-
fR!(cache.R[1], u, p, τ)
129+
fL!(L[1], u, p, τ)
130+
fR!(R[1], u, p, τ)
132131

133132
for i in 2:Nstages
134133
# solve for W * U = Uhat
135134
# set U to initial guess based on fully explicit
136135
# TODO: we don't need this for direct solves
137136
U .= Uhat .= u
138137
for j in 1:(i - 1)
139-
Uhat .+= (dt * tab.Aimpl[i, j]) .* cache.L[j] .+ (dt * tab.Aexpl[i, j]) .* cache.R[j]
138+
Uhat .+= (dt * Aimpl[i, j]) .* L[j] .+ (dt * Aexpl[i, j]) .* R[j]
140139
# initial value: we only need to do this if using an iterative method:
141-
U .+= (dt * tab.Aexpl[i, j]) .* (cache.L[j] .+ cache.R[j])
140+
U .+= (dt * Aexpl[i, j]) .* (L[j] .+ R[j])
142141
end
143142

144143
# W = I - dt * Aimpl[i,i] * L
145144
# currently only use SDIRK methods where
146145
# Aimpl[i,i] = i == 1 ? 0 : const
147146
# TODO: handle changing dt & Aimpl[i,i] coeffs
148-
if !(DiffEqBase.isconstant(cache.W))
149-
cache.W.t = τ
147+
if !(DiffEqBase.isconstant(W))
148+
W.t = τ
150149
W_updated = true
151150
end
152-
γ = -dt * tab.Aimpl[i, i]
153-
if cache.W.γ != γ
154-
cache.W.γ = γ
151+
γ = -dt * Aimpl[i, i]
152+
if W.γ != γ
153+
W.γ = γ
155154
W_updated = true
156155
end
157-
cache.linsolve!(U, cache.W, Uhat, W_updated)
156+
linsolve!(U, W, Uhat, W_updated)
158157

159-
τ = t + tab.C[i] * dt
160-
fL!(cache.L[i], U, p, τ)
161-
# or use cache.L[i] .= (U .- Uhat) ./ (dt * tab.Aimpl[i,i]) ?
162-
fR!(cache.R[i], U, p, τ)
158+
τ = t + C[i] * dt
159+
fL!(L[i], U, p, τ)
160+
# or use L[i] .= (U .- Uhat) ./ (dt * Aimpl[i,i]) ?
161+
fR!(R[i], U, p, τ)
163162
end
164163

165164
# compute next step
166165
for i in 1:Nstages
167-
u .+= (dt * tab.B[i]) .* (cache.L[i] .+ cache.R[i])
166+
u .+= (dt * B[i]) .* (L[i] .+ R[i])
168167
end
169168
end
170169

171170
# WIP
172171
function step_u!(int, cache::AdditiveRungeKuttaFullCache{Nstages}, f::DiffEqBase.ODEFunction) where {Nstages}
173-
tab = cache.tableau
174-
U = cache.U
175-
Uhat = cache.R[end] # can be used as work array, as only used in last stage
176-
177-
178-
u = int.u
179-
p = int.p
180-
t = int.t
181-
dt = int.dt
172+
(; C, Aimpl, Aexpl, B) = cache.tableau
173+
(; u, p, t, dt) = int
174+
(; U, R, L, W, linsolve) = cache
175+
Uhat = R[end] # can be used as work array, as only used in last stage
182176

183177
fL! = f.jvp # linear part
184178
f! = f.f # remainder
185179

186180
# first stage is always explicit
187-
τ = t + tab.C[1] * dt
181+
τ = t + C[1] * dt
188182

189-
f!(cache.R[1], u, p, τ)
190-
fL!(cache.L[1], u, p, τ)
183+
f!(R[1], u, p, τ)
184+
fL!(L[1], u, p, τ)
191185
for i in 2:Nstages
192186
# solve for W * U = Uhat
193187
# set U to initial guess based on fully explicit
194188
# TODO: we don't need this for direct solves
195189
U .= Uhat .= u
196190
for j in 1:(i - 1)
197-
Uhat .+= (dt * tab.Aimpl[i, j]) .* cache.L[j] .+ (dt * tab.Aexpl[i, j]) .* (cache.R[j] .- cache.L[j])
191+
Uhat .+= (dt * Aimpl[i, j]) .* L[j] .+ (dt * Aexpl[i, j]) .* (R[j] .- L[j])
198192
# initial value: we only need to do this if using an iterative method:
199-
U .+= (dt * tab.Aexpl[i, j]) .* cache.R[j]
193+
U .+= (dt * Aexpl[i, j]) .* R[j]
200194
end
201195
#=
202196
# solve for W * U = Uhat
@@ -205,36 +199,36 @@ function step_u!(int, cache::AdditiveRungeKuttaFullCache{Nstages}, f::DiffEqBase
205199
V .= u
206200
Ω .= 0
207201
for j = 1:i-1
208-
V .+= dt * tab.Aexpl[i,j] .* cache.R[j]
209-
Ω .+= (tab.Aimpl[i,j]-tab.Aexpl[i,j])/tab.Aimpl[i,i] .* cache.U[j]
202+
V .+= dt * Aexpl[i,j] .* R[j]
203+
Ω .+= (Aimpl[i,j]-Aexpl[i,j])/Aimpl[i,i] .* U[j]
210204
end
211205
Vhat .= V .+ Ω
212206
=#
213207
# W = I - dt * Aimpl[i,i] * L
214208
# currently only use SDIRK methods where
215209
# Aimpl[i,i] = i == 1 ? 0 : const
216210
# TODO: handle changing dt & Aimpl[i,i] coeffs
217-
if !(DiffEqBase.isconstant(cache.W))
218-
cache.W.t = τ
211+
if !(DiffEqBase.isconstant(W))
212+
W.t = τ
219213
W_updated = true
220214
end
221-
γ = -dt * tab.Aimpl[i, i]
222-
if cache.W.γ != γ
223-
cache.W.γ = γ
215+
γ = -dt * Aimpl[i, i]
216+
if W.γ != γ
217+
W.γ = γ
224218
W_updated = true
225219
end
226-
cache.linsolve!(U, cache.W, Uhat, W_updated)
220+
linsolve!(U, W, Uhat, W_updated)
227221
# U = V .- Ω
228222

229-
τ = t + tab.C[i] * dt
230-
fL!(cache.L[i], U, p, τ)
231-
# or use cache.L[i] .= (U .- Uhat) ./ (dt * tab.Aimpl[i,i]) ?
232-
f!(cache.R[i], U, p, τ)
223+
τ = t + C[i] * dt
224+
fL!(L[i], U, p, τ)
225+
# or use L[i] .= (U .- Uhat) ./ (dt * Aimpl[i,i]) ?
226+
f!(R[i], U, p, τ)
233227
end
234228

235229
# compute next step
236230
for i in 1:Nstages
237-
u .+= (dt * tab.B[i]) .* cache.R[i]
231+
u .+= (dt * B[i]) .* R[i]
238232
end
239233
end
240234

src/solvers/lsrk.jl

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,18 @@ function init_cache(prob::DiffEqBase.ODEProblem, alg::LowStorageRungeKutta2N; kw
4040
return LowStorageRungeKutta2NIncCache(tableau(alg, eltype(du)), du)
4141
end
4242

43-
nstages(::LowStorageRungeKutta2NIncCache{N}) where {N} = N
43+
n_stages(::LowStorageRungeKutta2NIncCache{N}) where {N} = N
4444

4545
function step_u!(int, cache::LowStorageRungeKutta2NIncCache)
46-
tab = cache.tableau
46+
(; C, A, B) = cache.tableau
4747
du = cache.du
48+
(; u, p, t, dt) = int
4849

49-
u = int.u
50-
p = int.p
51-
t = int.t
52-
dt = int.dt
53-
54-
for stage in 1:nstages(cache)
55-
# du .= f(u, p, t + tab.C[stage]*dt) .+ tab.A[stage] .* du
56-
stage_time = t + tab.C[stage] * dt
57-
int.sol.prob.f(du, u, p, stage_time, 1, tab.A[stage])
58-
u .+= (dt * tab.B[stage]) .* du
50+
for stage in 1:n_stages(cache)
51+
# du .= f(u, p, t + C[stage]*dt) .+ A[stage] .* du
52+
stage_time = t + C[stage] * dt
53+
int.sol.prob.f(du, u, p, stage_time, 1, A[stage])
54+
u .+= (dt * B[stage]) .* du
5955
end
6056
end
6157

@@ -65,22 +61,23 @@ function init_inner(prob, outercache::LowStorageRungeKutta2NIncCache, dt)
6561
end
6662
function update_inner!(innerinteg, outercache::LowStorageRungeKutta2NIncCache, f_slow, u, p, t, dt, stage)
6763

64+
(; C, A, B) = cache.tableau
6865
f_offset = innerinteg.sol.prob.f
6966
tab = outercache.tableau
70-
N = nstages(outercache)
67+
N = n_stages(outercache)
7168

72-
τ0 = t + tab.C[stage] * dt
73-
τ1 = stage == N ? t + dt : t + tab.C[stage + 1] * dt
69+
τ0 = t + C[stage] * dt
70+
τ1 = stage == N ? t + dt : t + C[stage + 1] * dt
7471
f_offset.α = τ0
7572
innerinteg.t = zero(τ0)
7673
innerinteg.tstop = τ1 - τ0
7774

78-
# du .= f(u, p, t + tab.C[stage]*dt) .+ tab.A[stage] .* du
79-
f_slow(f_offset.x, u, p, τ0, 1, tab.A[stage])
75+
# du .= f(u, p, t + C[stage]*dt) .+ A[stage] .* du
76+
f_slow(f_offset.x, u, p, τ0, 1, A[stage])
8077

81-
C0 = tab.C[stage]
82-
C1 = stage == N ? one(tab.C[stage]) : tab.C[stage + 1]
83-
f_offset.γ = tab.B[stage] / (C1 - C0)
78+
C0 = C[stage]
79+
C1 = stage == N ? one(C[stage]) : C[stage + 1]
80+
f_offset.γ = B[stage] / (C1 - C0)
8481
end
8582

8683

src/solvers/mis.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ struct MultirateInfinitesimalStepCache{Nstages, A, T <: MultirateInfinitesimalSt
4949
tableau::T
5050
end
5151

52-
nstages(::MultirateInfinitesimalStepCache{Nstages}) where {Nstages} = Nstages
52+
n_stages(::MultirateInfinitesimalStepCache{Nstages}) where {Nstages} = Nstages
5353

5454
function init_cache(
5555
prob::DiffEqBase.AbstractODEProblem{uType, tType, true},
@@ -76,15 +76,15 @@ end
7676
function update_inner!(innerinteg, outercache::MultirateInfinitesimalStepCache, f_slow, u, p, t, dt, i)
7777

7878
f_offset = innerinteg.sol.prob.f
79-
tab = outercache.tableau
80-
N = nstages(outercache)
79+
N = n_stages(outercache)
80+
(; c, c̃, d) = outercache.tableau
8181

8282
F = outercache.F
8383
ΔU = outercache.ΔU
8484

8585
# F[i] = f_slow(U[i-1], p, t + c[i-1]*dt)
8686
u0 = i == 1 ? u : ΔU[i - 1]
87-
t0 = i == 1 ? t : t + tab.c[i - 1] * dt
87+
t0 = i == 1 ? t : t + c[i - 1] * dt
8888
f_slow(F[i], u0, p, t0)
8989

9090
# the (i+1)th stage of the paper
@@ -109,16 +109,17 @@ function update_inner!(innerinteg, outercache::MultirateInfinitesimalStepCache,
109109

110110
# KW2014 (9)
111111
# evaluate f_fast(z(τ), p, t + c̃[i]*dt + (c[i]-c̃[i])/d[i] * τ)
112-
f_offset.α = t + tab.c̃[i] * dt
113-
f_offset.β = (tab.c[i] - tab.c̃[i]) / tab.d[i]
112+
f_offset.α = t + c̃[i] * dt
113+
f_offset.β = (c[i] - c̃[i]) / d[i]
114114

115115
innerinteg.t = zero(t)
116-
innerinteg.tstop = tab.d[i] * dt
116+
innerinteg.tstop = d[i] * dt
117117
end
118118

119119
@kernel function mis_update!(u, ΔU, F, innerinteg_u, f_offset_x, tab, i, N, dt)
120120
e = @index(Global, Linear)
121121
@inbounds begin
122+
(; α, β, d, γ) = tab
122123
if i > 1
123124
ΔU[i - 1][e] -= u[e]
124125
end
@@ -128,13 +129,13 @@ end
128129
innerinteg_u[e] = u[e]
129130
end
130131
for j in 1:(i - 1)
131-
innerinteg_u[e] += tab.α[i, j] * ΔU[j][e]
132+
innerinteg_u[e] += α[i, j] * ΔU[j][e]
132133
end
133134

134135
# KW2014 (1b) / (9)
135-
f_offset_x[e] = tab.β[i, i] / tab.d[i] .* F[i][e]
136+
f_offset_x[e] = β[i, i] / d[i] .* F[i][e]
136137
for j in 1:(i - 1)
137-
f_offset_x[e] += (tab.γ[i, j] / (tab.d[i] * dt)) * ΔU[j][e] + tab.β[i, j] / tab.d[i] * F[j][e]
138+
f_offset_x[e] += (γ[i, j] / (d[i] * dt)) * ΔU[j][e] + β[i, j] / d[i] * F[j][e]
138139
end
139140
end
140141
end

src/solvers/multirate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function step_u!(int, cache::MultirateCache)
5151
innerinteg = cache.innerinteg
5252
fast_dt = innerinteg.dt
5353

54-
N = nstages(outercache)
54+
N = n_stages(outercache)
5555
for stage in 1:N
5656

5757
update_inner!(innerinteg, outercache, int.sol.prob.f.f2, u, p, t, dt, stage)

src/solvers/rosenbrock.jl

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,39 +33,30 @@ end
3333

3434

3535
function step_u!(int, cache::RosenbrockCache{Nstages, RT}) where {Nstages, RT}
36-
tab = cache.tableau
37-
36+
(; m, Γ, A, C) = cache.tableau
37+
(; u, p, t, dt) = int
38+
(; W, U, fU, k, linsolve!) = cache
3839
f! = int.sol.prob.f
3940
Wfact_t! = int.sol.prob.f.Wfact_t
4041

41-
u = int.u
42-
p = int.p
43-
t = int.t
44-
dt = int.dt
45-
W = cache.W
46-
U = cache.U
47-
fU = cache.fU
48-
k = cache.k
49-
linsolve! = cache.linsolve!
50-
5142
# 1) compute jacobian factorization
52-
γ = dt * tab.Γ[1, 1]
43+
γ = dt * Γ[1, 1]
5344
Wfact_t!(W, u, p, γ, t)
5445
for i in 1:Nstages
5546
U .= u
5647
for j in 1:(i - 1)
57-
U .+= tab.A[i, j] .* k[j]
48+
U .+= A[i, j] .* k[j]
5849
end
5950
# TODO: there should be a time modification here (t + c * dt)
6051
# if f does depend on time, would need to add tgrad term as well
6152
f!(fU, U, p, t)
6253
for j in 1:(i - 1)
63-
fU .+= (tab.C[i, j] / dt) .* k[j]
54+
fU .+= (C[i, j] / dt) .* k[j]
6455
end
6556
linsolve!(k[i], W, fU)
6657
end
6758
for i in 1:Nstages
68-
u .+= tab.m[i] .* k[i]
59+
u .+= m[i] .* k[i]
6960
end
7061
end
7162

0 commit comments

Comments
 (0)