Skip to content

Commit 6dfddd2

Browse files
committed
Reformulate IMEX ARK to correctly account for Newton residuals and DSS
1 parent f2e2b71 commit 6dfddd2

File tree

4 files changed

+371
-10
lines changed

4 files changed

+371
-10
lines changed

src/ClimaTimeSteppers.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ array_device(x) = CUDADevice() # assume CUDA
5959

6060
import DiffEqBase, SciMLBase, LinearAlgebra, DiffEqCallbacks, Krylov
6161

62+
include(joinpath("utilities", "sparse_tuple.jl"))
6263
include(joinpath("utilities", "sparse_coeffs.jl"))
6364
include(joinpath("utilities", "fused_increment.jl"))
6465
include("sparse_containers.jl")
@@ -118,7 +119,7 @@ const SPCO = SparseCoeffs
118119

119120
include("solvers/imex_tableaus.jl")
120121
include("solvers/explicit_tableaus.jl")
121-
include("solvers/imex_ark.jl")
122+
include("solvers/imex_ark_new.jl")
122123
include("solvers/imex_ssprk.jl")
123124
include("solvers/multirate.jl")
124125
include("solvers/lsrk.jl")

src/solvers/imex_ark_new.jl

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
"""
    has_jac(T_imp!)

Return `true` when `T_imp!` has both a `Wfact` field and a `jac_prototype`
field, and neither of them is `nothing`; return `false` otherwise.
"""
function has_jac(T_imp!)
    F = typeof(T_imp!)
    # Without both fields there is nothing to inspect (and property access
    # below could throw), so bail out first.
    (hasfield(F, :Wfact) && hasfield(F, :jac_prototype)) || return false
    return !(isnothing(T_imp!.Wfact) || isnothing(T_imp!.jac_prototype))
end
6+
7+
"""
    imp_error(name)

Throw an error stating that the tableau called `name` (or "The given
IMEXTableau" when `name` is `nothing`) has implicit stages and therefore needs
a `NewtonsMethod` to be specified alongside `T_imp!`.
"""
function imp_error(name)
    subject = isnothing(name) ? "The given IMEXTableau" : name
    return error("$(subject) \
                  has implicit stages that require a nonlinear solver, \
                  so NewtonsMethod must be specified alongside T_imp!.")
end
10+
11+
"""
    sdirk_error(name)

Throw an error stating that the tableau called `name` (or the generic
"given IMEXTableau" when `name` is `nothing`) is not SDIRK, so the Jacobian
update must not be tied to the `NewTimeStep` signal.
"""
function sdirk_error(name)
    # The tableau is mentioned twice in the message, with a different
    # fallback phrase each time when it has no name.
    subject, object = isnothing(name) ? ("The given IMEXTableau", "this tableau") : (name, name)
    return error("$(subject) \
                  has implicit stages with distinct coefficients (it \
                  is not SDIRK), and an update is required whenever a \
                  stage has a different coefficient from the previous \
                  stage. Do not update on the NewTimeStep signal when \
                  using $(object).")
end
17+
18+
"""
    IMEXARKCache{SC, NMC}

Cache returned by `init_cache` for an `IMEXAlgorithm{Unconstrained}` and
consumed by `step_u!`.

# Fields
- `stage_cache`: per-stage storage and precomputed coefficients.
- `newtons_method_cache`: cache for the nonlinear solver, or `nothing` when no
  `newtons_method` was given.
"""
struct IMEXARKCache{SC, NMC}
    stage_cache::SC
    newtons_method_cache::NMC
end
22+
23+
"""
    init_cache(prob, alg::IMEXAlgorithm{Unconstrained}; kwargs...)

Build the `IMEXARKCache` that `step_u!` uses for the unconstrained IMEX ARK
algorithm `alg` applied to `prob`.

The cache holds one `similar(prob.u0)` buffer per stage whose tendency (or
implicit increment `ΔU_imp`) is needed later, together with the tableau
coefficients rearranged so that each stage value and the final state (treated
as stage `s + 1`) are linear combinations of those buffers.

Throws (via `imp_error`) if the tableau has stages that require an implicit
solve, `T_imp!` is given, and `alg.newtons_method` is `nothing`.
"""
function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{Unconstrained}; kwargs...)
    (; u0) = prob
    (; T_lim!, T_exp!, T_exp_T_lim!, T_imp!) = prob.f
    (; name, newtons_method) = alg
    (; a_exp, b_exp, a_imp, b_imp) = alg.tableau

    no_T_lim = isnothing(T_lim!) && isnothing(T_exp_T_lim!)
    no_T_exp = isnothing(T_exp!) && isnothing(T_exp_T_lim!)
    no_T_imp = isnothing(T_imp!)

    s = size(a_imp, 1) # number of stages (`length` does not accept a dimension)

    A_exp = vcat(a_exp, b_exp') # exp coefs with final state seen as stage s + 1
    A_imp = vcat(a_imp, b_imp') # imp coefs with final state seen as stage s + 1

    # The stage partition must be known before Γ can be extracted.
    diag_a_imp = LinearAlgebra.diag(a_imp)
    z_stages = findall(iszero, diag_a_imp) # stages without implicit solves
    nz_stages = findall(!iszero, diag_a_imp) # stages with implicit solves

    Γ = a_imp[nz_stages, nz_stages] # "fully implicit" part of a_imp
    γs = LinearAlgebra.diag(Γ)
    sdirk_γ = length(unique(γs)) == 1 ? γs[1] : nothing

    exp_stages = findall(col -> any(!iszero, col), eachcol(A_exp))
    # Keep actual stage indices (not positions within z_stages), so that these
    # labels agree with the column labels given to sparse_matrix_rows below.
    imp_z_stages = filter(stage -> any(!iszero, A_imp[:, stage]), z_stages)

    temp_value = similar(u0)
    T_lim_by_stage = no_T_lim ? SparseTuple() : SparseTuple(map(_ -> similar(u0), exp_stages), exp_stages)
    T_exp_by_stage = no_T_exp ? SparseTuple() : SparseTuple(map(_ -> similar(u0), exp_stages), exp_stages)
    T_imp_by_stage = no_T_imp ? SparseTuple() : SparseTuple(map(_ -> similar(u0), imp_z_stages), imp_z_stages)
    ΔU_imp_by_stage = no_T_imp ? SparseTuple() : SparseTuple(map(_ -> similar(u0), nz_stages), nz_stages)

    isnothing(newtons_method) && !isempty(ΔU_imp_by_stage) && imp_error(name)

    prev_ΔtT_lim_to_Δu_coefs_by_stage = no_T_lim ? SparseTuple() : sparse_matrix_rows(A_exp, 1:(s + 1), 1:s)
    prev_ΔtT_exp_to_Δu_coefs_by_stage = no_T_exp ? SparseTuple() : sparse_matrix_rows(A_exp, 1:(s + 1), 1:s)
    prev_z_ΔtT_imp_to_Δu_coefs_by_stage =
        no_T_imp ? SparseTuple() : sparse_matrix_rows(A_imp[:, z_stages], 1:(s + 1), z_stages)

    # Coefficients of Δt * T_imp from *previous* implicitly solved stages: the
    # diagonal of Γ is removed because a stage's own implicit term is handled
    # by its solve. The matrix already has only the nz_stages columns, so all
    # of its columns are selected with `:` (indexing columns by stage number
    # would be out of bounds whenever nz_stages != 1:length(nz_stages)).
    prev_nz_ΔtT_imp_to_Δu_coef_matrix = A_imp[:, nz_stages]
    prev_nz_ΔtT_imp_to_Δu_coef_matrix[nz_stages, :] .= Γ - LinearAlgebra.Diagonal(Γ)
    # Multiplying by inv(Γ) re-expresses those coefficients in terms of the
    # stored ΔU_imp values instead of Δt * T_imp.
    prev_nz_ΔU_imp_to_Δu_coefs_by_stage =
        no_T_imp ? SparseTuple() : sparse_matrix_rows(prev_nz_ΔtT_imp_to_Δu_coef_matrix * inv(Γ), 1:(s + 1), nz_stages)

    # Convert all inputs to unrolled_foreach in step_u! into regular tuples.
    # The coefficient tuples have s + 1 entries, because step_u! reads entry
    # s + 1 when it assembles the final state.
    T_lim_by_stage_dense = dense_tuple(T_lim_by_stage, s, nothing)
    T_exp_by_stage_dense = dense_tuple(T_exp_by_stage, s, nothing)
    T_imp_by_stage_dense = dense_tuple(T_imp_by_stage, s, nothing)
    ΔU_imp_by_stage_dense = dense_tuple(ΔU_imp_by_stage, s, nothing)
    prev_ΔtT_lim_to_Δu_coefs_by_stage_dense = dense_tuple(prev_ΔtT_lim_to_Δu_coefs_by_stage, s + 1, SparseTuple())
    prev_ΔtT_exp_to_Δu_coefs_by_stage_dense = dense_tuple(prev_ΔtT_exp_to_Δu_coefs_by_stage, s + 1, SparseTuple())
    prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense = dense_tuple(prev_z_ΔtT_imp_to_Δu_coefs_by_stage, s + 1, SparseTuple())
    prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense = dense_tuple(prev_nz_ΔU_imp_to_Δu_coefs_by_stage, s + 1, SparseTuple())

    stage_cache = (;
        s,
        sdirk_γ,
        temp_value,
        T_lim_by_stage,
        T_exp_by_stage,
        T_imp_by_stage,
        ΔU_imp_by_stage,
        T_lim_by_stage_dense,
        T_exp_by_stage_dense,
        T_imp_by_stage_dense,
        ΔU_imp_by_stage_dense,
        prev_ΔtT_lim_to_Δu_coefs_by_stage_dense,
        prev_ΔtT_exp_to_Δu_coefs_by_stage_dense,
        prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense,
        prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense,
    )

    newtons_method_cache =
        isnothing(newtons_method) ? nothing :
        allocate_cache(newtons_method, u0, has_jac(T_imp!) ? T_imp!.jac_prototype : nothing)

    return IMEXARKCache(stage_cache, newtons_method_cache)
end
97+
98+
"""
    step_u!(integrator, cache::IMEXARKCache)

Advance `integrator.u` in place by one IMEX ARK step of size `integrator.dt`,
using the buffers and coefficients stored in `cache`, and return `u`.

For every stage, the value before the implicit solve is assembled from the
tendencies and implicit increments of earlier stages (applying `lim!` to the
limited part), `post_explicit!` is called, an implicit solve is performed with
`solve_newton!` when the stage has storage for `ΔU_imp`, and the tendencies
needed by later stages are evaluated (with `dss!` applied to the explicit and
limited ones). The final state is assembled in the same way from row `s + 1`
of the coefficients, followed by `dss!` and `post_explicit!`.

Throws (via `sdirk_error`) if the Jacobian is to be updated on the
`NewTimeStep` signal but the tableau has no single implicit coefficient.
"""
function step_u!(integrator, cache::IMEXARKCache)
    (; u, p, t, alg) = integrator
    (; T_lim!, T_exp!, T_exp_T_lim!, T_imp!) = integrator.sol.prob.f
    (; lim!, dss!, post_explicit!, post_implicit!) = integrator.sol.prob.f
    (; name, newtons_method) = alg
    (; a_imp, c_exp, c_imp) = alg.tableau
    (; newtons_method_cache) = cache
    (;
        s,
        sdirk_γ,
        temp_value,
        T_lim_by_stage,
        T_exp_by_stage,
        T_imp_by_stage,
        ΔU_imp_by_stage,
        T_lim_by_stage_dense,
        T_exp_by_stage_dense,
        T_imp_by_stage_dense,
        ΔU_imp_by_stage_dense,
        prev_ΔtT_lim_to_Δu_coefs_by_stage_dense,
        prev_ΔtT_exp_to_Δu_coefs_by_stage_dense,
        prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense,
        prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense,
    ) = cache.stage_cache

    Δt = integrator.dt

    # Update the Jacobian once per step if requested; this only makes sense
    # when every implicit stage shares the same coefficient (sdirk_γ).
    if !isnothing(T_imp!) && !isnothing(newtons_method)
        (; update_j) = newtons_method
        (; j) = newtons_method_cache
        if !isnothing(j) && needs_update!(update_j, NewTimeStep(t))
            isnothing(sdirk_γ) && sdirk_error(name)
            T_imp!.Wfact(j, u, p, Δt * sdirk_γ, t)
        end
    end

    unrolled_foreach(
        1:s,
        T_lim_by_stage_dense[1:s],
        T_exp_by_stage_dense[1:s],
        T_imp_by_stage_dense[1:s],
        ΔU_imp_by_stage_dense[1:s],
        prev_ΔtT_lim_to_Δu_coefs_by_stage_dense[1:s],
        prev_ΔtT_exp_to_Δu_coefs_by_stage_dense[1:s],
        prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense[1:s],
        prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense[1:s],
    ) do (
        stage,
        T_lim,
        T_exp,
        T_imp,
        ΔU_imp,
        prev_ΔtT_lim_to_Δu_coefs,
        prev_ΔtT_exp_to_Δu_coefs,
        prev_z_ΔtT_imp_to_Δu_coefs,
        prev_nz_ΔU_imp_to_Δu_coefs,
    )
        t_exp = t + Δt * c_exp[stage]
        t_imp = t + Δt * c_imp[stage]
        Δtγ = Δt * a_imp[stage, stage]

        Δu_over_Δt_from_prev_T_lim = broadcasted_dot(prev_ΔtT_lim_to_Δu_coefs, T_lim_by_stage)
        Δu_over_Δt_from_prev_T_exp = broadcasted_dot(prev_ΔtT_exp_to_Δu_coefs, T_exp_by_stage)
        Δu_over_Δt_from_prev_T_imp = broadcasted_dot(prev_z_ΔtT_imp_to_Δu_coefs, T_imp_by_stage)
        Δu_from_prev_ΔU_imp = broadcasted_dot(prev_nz_ΔU_imp_to_Δu_coefs, ΔU_imp_by_stage)

        # Limited increment: skipped entirely when no earlier stage contributes.
        if isempty(prev_ΔtT_lim_to_Δu_coefs)
            u_plus_Δu_lim = u
        else
            u_plus_Δu_lim = temp_value
            @. u_plus_Δu_lim = u + Δt * Δu_over_Δt_from_prev_T_lim
            lim!(u_plus_Δu_lim, p, t_exp, u)
        end

        if (
            isempty(prev_ΔtT_exp_to_Δu_coefs) &&
            isempty(prev_z_ΔtT_imp_to_Δu_coefs) &&
            isempty(prev_nz_ΔU_imp_to_Δu_coefs)
        )
            U_before_solve = u_plus_Δu_lim
        else
            # temp_value may alias u_plus_Δu_lim here; the update below is
            # written as a single elementwise broadcast.
            U_before_solve = temp_value
            @. U_before_solve =
                u_plus_Δu_lim + Δt * (Δu_over_Δt_from_prev_T_exp + Δu_over_Δt_from_prev_T_imp) + Δu_from_prev_ΔU_imp
        end

        post_explicit!(U_before_solve, p, t_imp)

        if !isnothing(ΔU_imp)
            # Use ΔU_imp as additional temporary storage, since its value does
            # not need to be set until the end of each stage.
            U = ΔU_imp
            @. U = U_before_solve

            # Solve U ≈ U_before_solve + Δtγ * T_imp(U, p, t_imp) for U.
            set_residual! = (residual, Ui) -> begin
                T_imp!(residual, Ui, p, t_imp)
                @. residual = U_before_solve - Ui + Δtγ * residual
            end
            set_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, Δtγ, t_imp)
            # This wrapper needs its own name: assigning to `post_implicit!`
            # inside this closure would rebind the captured outer variable, so
            # the wrapper would call itself instead of the user's callback.
            call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp)
            solve_newton!(
                newtons_method,
                newtons_method_cache,
                U,
                set_residual!,
                set_jacobian!,
                call_post_implicit!,
                call_post_implicit!,
            )
        else
            U = U_before_solve # There is no solve on this stage.
        end

        if !isnothing(T_lim) || !isnothing(T_exp)
            if !isnothing(T_exp_T_lim!)
                T_exp_T_lim!(T_exp, T_lim, U, p, t_exp)
                dss!(T_lim, p, t_exp)
                dss!(T_exp, p, t_exp)
            end
            if !isnothing(T_lim!)
                T_lim!(T_lim, U, p, t_exp)
                dss!(T_lim, p, t_exp)
            end
            if !isnothing(T_exp!)
                T_exp!(T_exp, U, p, t_exp)
                dss!(T_exp, p, t_exp)
            end
            # TODO: Can we just use T_exp_T_lim!, and not 3 different functions?
            # TODO: Fuse the DSS calls above into a single operation.
        end

        if !isnothing(ΔU_imp)
            # U is the same array as ΔU_imp, so this overwrites the stage value
            # with the implicit increment in place.
            # TODO: Subtract the T_lim contribution from this term.
            @. ΔU_imp = U - u - Δt * Δu_over_Δt_from_prev_T_exp
        elseif !isnothing(T_imp)
            T_imp!(T_imp, U, p, t_imp)
        end
    end

    # Assemble the final state from row s + 1 of the coefficients. All locals
    # below carry a `_final` suffix so that none of them shares a name with a
    # variable assigned inside the stage closure above.
    t_final = t + Δt
    prev_ΔtT_lim_to_Δu_coefs_final = prev_ΔtT_lim_to_Δu_coefs_by_stage_dense[s + 1]
    prev_ΔtT_exp_to_Δu_coefs_final = prev_ΔtT_exp_to_Δu_coefs_by_stage_dense[s + 1]
    prev_z_ΔtT_imp_to_Δu_coefs_final = prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense[s + 1]
    prev_nz_ΔU_imp_to_Δu_coefs_final = prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense[s + 1]
    Δu_over_Δt_from_prev_T_lim_final = broadcasted_dot(prev_ΔtT_lim_to_Δu_coefs_final, T_lim_by_stage)
    Δu_over_Δt_from_prev_T_exp_final = broadcasted_dot(prev_ΔtT_exp_to_Δu_coefs_final, T_exp_by_stage)
    Δu_over_Δt_from_prev_T_imp_final = broadcasted_dot(prev_z_ΔtT_imp_to_Δu_coefs_final, T_imp_by_stage)
    Δu_from_prev_ΔU_imp_final = broadcasted_dot(prev_nz_ΔU_imp_to_Δu_coefs_final, ΔU_imp_by_stage)

    if isempty(prev_ΔtT_lim_to_Δu_coefs_final)
        u_plus_Δu_lim_final = u
    else
        u_plus_Δu_lim_final = temp_value
        @. u_plus_Δu_lim_final = u + Δt * Δu_over_Δt_from_prev_T_lim_final
        lim!(u_plus_Δu_lim_final, p, t_final, u)
    end

    @. u =
        u_plus_Δu_lim_final +
        Δt * (Δu_over_Δt_from_prev_T_exp_final + Δu_over_Δt_from_prev_T_imp_final) +
        Δu_from_prev_ΔU_imp_final
    dss!(u, p, t_final)
    post_explicit!(u, p, t_final)

    return u
end

src/solvers/imex_tableaus.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,14 @@ default values for `c_exp` and `c_imp` assume that it is internally consistent.
1919
The explicit tableau must be strictly lower triangular, and the implicit tableau
2020
must be lower triangular (only DIRK algorithms are currently supported).
2121
"""
22-
struct IMEXTableau{AE <: SPCO, BE <: SPCO, CE <: SPCO, AI <: SPCO, BI <: SPCO, CI <: SPCO}
23-
a_exp::AE # matrix of size s×s
24-
b_exp::BE # vector of length s
25-
c_exp::CE # vector of length s
26-
a_imp::AI # matrix of size s×s
27-
b_imp::BI # vector of length s
28-
c_imp::CI # vector of length s
29-
end
30-
IMEXTableau(args...) = IMEXTableau(map(x -> SparseCoeffs(x), args)...)
22+
struct IMEXTableau{Mat, Vec}
    # Explicit part of the tableau.
    a_exp::Mat # s×s matrix
    b_exp::Vec # length-s vector
    c_exp::Vec # length-s vector
    # Implicit part of the tableau; each entry has the same type as its
    # explicit counterpart.
    a_imp::Mat # s×s matrix
    b_imp::Vec # length-s vector
    c_imp::Vec # length-s vector
end
3130

3231
function IMEXTableau(;
3332
a_exp,

0 commit comments

Comments
 (0)