Skip to content

Commit 5418a44

Browse files
committed
Reformulate IMEX ARK to correctly account for Newton residuals and DSS
1 parent f2e2b71 commit 5418a44

File tree

4 files changed

+396
-10
lines changed

4 files changed

+396
-10
lines changed

src/ClimaTimeSteppers.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ array_device(x) = CUDADevice() # assume CUDA
5959

6060
import DiffEqBase, SciMLBase, LinearAlgebra, DiffEqCallbacks, Krylov
6161

62+
include(joinpath("utilities", "sparse_tuple.jl"))
6263
include(joinpath("utilities", "sparse_coeffs.jl"))
6364
include(joinpath("utilities", "fused_increment.jl"))
6465
include("sparse_containers.jl")
@@ -118,7 +119,7 @@ const SPCO = SparseCoeffs
118119

119120
include("solvers/imex_tableaus.jl")
120121
include("solvers/explicit_tableaus.jl")
121-
include("solvers/imex_ark.jl")
122+
include("solvers/imex_ark_new.jl")
122123
include("solvers/imex_ssprk.jl")
123124
include("solvers/multirate.jl")
124125
include("solvers/lsrk.jl")

src/solvers/imex_ark_new.jl

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
"""
    has_jac(T_imp!)

Return `true` when `T_imp!` has both a `Wfact` field and a `jac_prototype`
field and neither of them is `nothing`; return `false` otherwise.
"""
function has_jac(T_imp!)
    F = typeof(T_imp!)
    # Both fields must exist before either one can be read.
    (hasfield(F, :Wfact) && hasfield(F, :jac_prototype)) || return false
    return !(isnothing(T_imp!.Wfact) || isnothing(T_imp!.jac_prototype))
end
6+
7+
"""
    imp_error(name)

Throw an error stating that the tableau called `name` (or "The given
IMEXTableau" when `name` is `nothing`) has implicit stages and therefore needs
a `NewtonsMethod` to be specified alongside `T_imp!`.
"""
function imp_error(name)
    subject = isnothing(name) ? "The given IMEXTableau" : name
    return error("$(subject) \
                  has implicit stages that require a nonlinear solver, \
                  so NewtonsMethod must be specified alongside T_imp!.")
end
10+
11+
"""
    sdirk_error(name)

Throw an error stating that the tableau called `name` (or "The given
IMEXTableau" when `name` is `nothing`) is not SDIRK, so the Jacobian must not
be updated on the `NewTimeStep` signal.
"""
function sdirk_error(name)
    subject = isnothing(name) ? "The given IMEXTableau" : name
    referent = isnothing(name) ? "this tableau" : name
    return error("$(subject) \
                  has implicit stages with distinct coefficients (it \
                  is not SDIRK), and an update is required whenever a \
                  stage has a different coefficient from the previous \
                  stage. Do not update on the NewTimeStep signal when \
                  using $(referent).")
end
17+
18+
"""
    IMEXARKCache(timestepper_cache, newtons_method_cache)

Pair of caches used by the IMEX ARK stepper. In this file, `init_cache` fills
`timestepper_cache` with a `NamedTuple` of temporaries and coefficient tuples,
and sets `newtons_method_cache` either to the result of `allocate_cache` or to
`nothing` when no stage requires an implicit solve.
"""
struct IMEXARKCache{TC, NC}
    timestepper_cache::TC
    newtons_method_cache::NC
end
22+
23+
"""
    init_cache(prob, alg::IMEXAlgorithm{Unconstrained}; kwargs...)

Build the `IMEXARKCache` used by `step_u!` for the IMEX ARK tableau stored in
`alg.tableau`.

The tableau matrices are extended with the `b` vectors as an extra row, so
that the final state of a step is treated as stage `s + 1`. Stages are split
into those with a zero diagonal entry in the implicit matrix (no implicit
solve) and those with a nonzero diagonal entry (implicit solve). Storage
similar to `prob.u0` is allocated only for the tendencies and implicit
increments that have a nonzero coefficient somewhere in the tableau.

Throws (via `imp_error`) when some stage needs an implicit solve but
`alg.newtons_method` is `nothing`. `kwargs` are accepted but not used here.
"""
function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{Unconstrained}; kwargs...)
    (; u0) = prob
    (; T_lim!, T_exp!, T_exp_T_lim!, T_imp!) = prob.f
    (; name, newtons_method) = alg
    (; a_exp, b_exp, a_imp, b_imp) = alg.tableau

    no_T_lim = isnothing(T_lim!) && isnothing(T_exp_T_lim!)
    no_T_exp = isnothing(T_exp!) && isnothing(T_exp_T_lim!)
    no_T_imp = isnothing(T_imp!)

    s = size(a_imp, 1) # number of internal stages

    # Extend the coefficient matrices a_exp and a_imp by interpreting the final
    # state on each step as stage s + 1.
    A_exp = vcat(a_exp, b_exp')
    A_imp = vcat(a_imp, b_imp')

    z_stages = findall(iszero, diag(A_imp)) # stages without implicit solves
    nz_stages = findall(!iszero, diag(A_imp)) # stages with implicit solves
    stages_needing_T_exp = findall(col -> any(!iszero, col), eachcol(A_exp))
    # Keep stage numbers here. A findall over the columns of A_imp[:, z_stages]
    # would return positions within z_stages, which only coincide with stage
    # numbers when z_stages == 1:k, whereas everything below keys T_imp values
    # by stage number.
    z_stages_needing_T_imp = filter(stage -> any(!iszero, A_imp[:, stage]), z_stages)
    # All nz stages are computed using ΔU_imp, rather than T_imp.

    Γ = A_imp[nz_stages, nz_stages] # "fully implicit" part of A_imp
    sdirk_γ = length(unique(diag(Γ))) == 1 ? diag(Γ)[1] : nothing

    temp_value1 = similar(u0)
    temp_value2 = similar(u0)
    T_lim_values_sparse = no_T_lim ? SparseTuple() : SparseTuple(_ -> similar(u0), stages_needing_T_exp)
    T_exp_values_sparse = no_T_exp ? SparseTuple() : SparseTuple(_ -> similar(u0), stages_needing_T_exp)
    T_imp_values_sparse = no_T_imp ? SparseTuple() : SparseTuple(_ -> similar(u0), z_stages_needing_T_imp)
    ΔU_imp_values_sparse = no_T_imp ? SparseTuple() : SparseTuple(_ -> similar(u0), nz_stages)

    ΔtT_lim_to_Δu_lim_tuples_sparse = no_T_lim ? SparseTuple() : sparse_matrix_rows(A_exp, 1:(s + 1), 1:s)
    ΔtT_exp_to_Δu_exp_tuples_sparse = no_T_exp ? SparseTuple() : sparse_matrix_rows(A_exp, 1:(s + 1), 1:s)
    ΔtT_imp_to_Δu_imp_tuples_sparse =
        no_T_imp ? SparseTuple() : sparse_matrix_rows(A_imp[:, z_stages], 1:(s + 1), z_stages)

    # A_imp restricted to the nz columns, with the diagonal of Γ removed from
    # the nz rows. The matrix has only length(nz_stages) columns, so its columns
    # are addressed with `:`; indexing them with the stage numbers in nz_stages
    # is out of bounds (or hits the wrong columns) whenever nz_stages != 1:k.
    A_imp_lower_nz_component = A_imp[:, nz_stages]
    A_imp_lower_nz_component[nz_stages, :] .= Γ - Diagonal(Γ)
    prev_ΔU_imp_to_Δu_imp_tuples_sparse =
        no_T_imp ? SparseTuple() : sparse_matrix_rows(A_imp_lower_nz_component * inv(Γ), 1:(s + 1), nz_stages)

    # Convert all values that will be passed to unrolled_foreach in step_u! into
    # tuples of length s.
    # NOTE(review): step_u! also indexes the four coefficient tuples below at
    # s + 1; confirm that dense_tuple keeps that entry when asked for length s.
    T_lim_values = dense_tuple(T_lim_values_sparse, s, nothing)
    T_exp_values = dense_tuple(T_exp_values_sparse, s, nothing)
    T_imp_values = dense_tuple(T_imp_values_sparse, s, nothing)
    ΔU_imp_values = dense_tuple(ΔU_imp_values_sparse, s, nothing)
    ΔtT_lim_to_Δu_lim_tuples = dense_tuple(ΔtT_lim_to_Δu_lim_tuples_sparse, s, SparseTuple())
    ΔtT_exp_to_Δu_exp_tuples = dense_tuple(ΔtT_exp_to_Δu_exp_tuples_sparse, s, SparseTuple())
    ΔtT_imp_to_Δu_imp_tuples = dense_tuple(ΔtT_imp_to_Δu_imp_tuples_sparse, s, SparseTuple())
    prev_ΔU_imp_to_Δu_imp_tuples = dense_tuple(prev_ΔU_imp_to_Δu_imp_tuples_sparse, s, SparseTuple())

    timestepper_cache = (;
        sdirk_γ,
        temp_value1,
        temp_value2,
        T_lim_values_sparse,
        T_exp_values_sparse,
        T_imp_values_sparse,
        ΔU_imp_values_sparse,
        T_lim_values,
        T_exp_values,
        T_imp_values,
        ΔU_imp_values,
        ΔtT_lim_to_Δu_lim_tuples,
        ΔtT_exp_to_Δu_exp_tuples,
        ΔtT_imp_to_Δu_imp_tuples,
        prev_ΔU_imp_to_Δu_imp_tuples,
    )

    # A Newton cache is only needed when some stage stores an implicit increment.
    newtons_method_cache = if is_accessible(ΔU_imp_values_sparse)
        isnothing(newtons_method) && imp_error(name)
        j = has_jac(T_imp!) ? T_imp!.jac_prototype : nothing
        allocate_cache(newtons_method, u0, j)
    else
        nothing
    end

    return IMEXARKCache(timestepper_cache, newtons_method_cache)
end
105+
106+
"""
    step_u!(integrator, cache::IMEXARKCache)

Advance `integrator.u` in place by one IMEX ARK step of size `integrator.dt`,
using the temporaries and coefficient tuples stored in `cache`, and return `u`.

For each of the `s` internal stages, the stage state is assembled from `u`, the
limited and explicit tendencies of earlier stages, and the implicit
contributions of earlier stages. If the stage has an implicit solve, the state
is then passed to `solve_newton!`. Afterwards the tendencies needed by later
stages are evaluated (and passed through `dss!` on every stage except stage
`s`). The final state is assembled the same way from row `s + 1` of the
extended tableau, followed by `dss!` and `post_implicit!`.

Throws (via `sdirk_error`) if the Jacobian is due for an update on the
`NewTimeStep` signal but the tableau has no single implicit coefficient.
"""
function step_u!(integrator, cache::IMEXARKCache)
    (; u, p, t, alg) = integrator
    (; T_lim!, T_exp!, T_exp_T_lim!, T_imp!) = integrator.sol.prob.f
    (; lim!, dss!, post_explicit!, post_implicit!) = integrator.sol.prob.f
    (; name, newtons_method) = alg
    (; a_imp, c_exp, c_imp) = alg.tableau
    (; newtons_method_cache) = cache
    (;
        sdirk_γ,
        temp_value1,
        temp_value2,
        T_lim_values_sparse,
        T_exp_values_sparse,
        T_imp_values_sparse,
        ΔU_imp_values_sparse,
        T_lim_values,
        T_exp_values,
        T_imp_values,
        ΔU_imp_values,
        ΔtT_lim_to_Δu_lim_tuples,
        ΔtT_exp_to_Δu_exp_tuples,
        ΔtT_imp_to_Δu_imp_tuples,
        prev_ΔU_imp_to_Δu_imp_tuples,
    ) = cache.timestepper_cache

    Δt = integrator.dt
    s = size(a_imp, 1)

    # Once-per-step Jacobian update; only meaningful when every implicit stage
    # shares the same coefficient sdirk_γ.
    if !isnothing(newtons_method_cache)
        (; update_j) = newtons_method
        (; j) = newtons_method_cache
        if !isnothing(j) && needs_update!(update_j, NewTimeStep(t))
            isnothing(sdirk_γ) && sdirk_error(name)
            T_imp!.Wfact(j, u, p, Δt * sdirk_γ, t)
        end
    end

    unrolled_foreach(
        ntuple(identity, s),
        T_lim_values,
        T_exp_values,
        T_imp_values,
        ΔU_imp_values,
        ΔtT_lim_to_Δu_lim_tuples,
        ΔtT_exp_to_Δu_exp_tuples,
        ΔtT_imp_to_Δu_imp_tuples,
        prev_ΔU_imp_to_Δu_imp_tuples,
    ) do (
        stage,
        T_lim,
        T_exp,
        T_imp,
        ΔU_imp,
        ΔtT_lim_to_ΔU_lim,
        ΔtT_exp_to_ΔU_exp,
        ΔtT_imp_to_ΔU_imp,
        prev_ΔU_imp_to_ΔU_imp,
    )
        t_exp = t + Δt * c_exp[stage]
        t_imp = t + Δt * c_imp[stage]
        Δtγ = Δt * a_imp[stage, stage]

        ΔU_lim_over_Δt = sparse_broadcasted_dot(ΔtT_lim_to_ΔU_lim, T_lim_values_sparse)
        ΔU_exp_over_Δt = sparse_broadcasted_dot(ΔtT_exp_to_ΔU_exp, T_exp_values_sparse)
        ΔU_imp_from_T_imp_over_Δt = sparse_broadcasted_dot(ΔtT_imp_to_ΔU_imp, T_imp_values_sparse)
        ΔU_imp_from_prev_ΔU_imp = sparse_broadcasted_dot(prev_ΔU_imp_to_ΔU_imp, ΔU_imp_values_sparse)

        if is_accessible(ΔtT_lim_to_ΔU_lim)
            u_plus_ΔU_lim = temp_value1
            @. u_plus_ΔU_lim = u + Δt * ΔU_lim_over_Δt
            lim!(u_plus_ΔU_lim, p, t_exp, u)
        else
            u_plus_ΔU_lim = u
        end

        if is_accessible(ΔtT_exp_to_ΔU_exp) || is_accessible(ΔtT_imp_to_ΔU_imp) || is_accessible(prev_ΔU_imp_to_ΔU_imp)
            U_before_solve = temp_value2
            @. U_before_solve =
                u_plus_ΔU_lim + Δt * (ΔU_exp_over_Δt + ΔU_imp_from_T_imp_over_Δt) + ΔU_imp_from_prev_ΔU_imp
        else
            U_before_solve = u_plus_ΔU_lim
        end

        # True when U_before_solve was written to a temporary, i.e. it is not u.
        is_not_u_before_solve =
            is_accessible(ΔtT_lim_to_ΔU_lim) ||
            is_accessible(ΔtT_exp_to_ΔU_exp) ||
            is_accessible(ΔtT_imp_to_ΔU_imp) ||
            is_accessible(prev_ΔU_imp_to_ΔU_imp)

        # TODO: Rename post_explicit! to pre_newton_iteration!, and rename
        # post_implicit! to post_stage!. Make pre_newton_iteration! only set
        # precomputed quantities needed by T_imp! and T_imp!.Wfact. Keep
        # post_stage! as it is now, so that it sets all precomputed quantities.

        if !isnothing(ΔU_imp)
            is_not_u_before_solve && post_explicit!(U_before_solve, p, t_imp)

            U = ΔU_imp # Use ΔU_imp as additional temporary storage.
            @. U = U_before_solve

            # Solve U ≈ U_before_solve + Δtγ * T_imp(U, p, t_imp) for U.
            solve_newton!(
                newtons_method,
                newtons_method_cache,
                U,
                (residual, U) -> begin
                    T_imp!(residual, U, p, t_imp)
                    @. residual = U_before_solve - U + Δtγ * residual
                end,
                (j, U) -> T_imp!.Wfact(j, U, p, Δtγ, t_imp), # j = ∂residual/∂U
                U -> post_explicit!(U, p, t_imp),
                U -> post_implicit!(U, p, t_imp),
            )
        else
            U = U_before_solve # There is no solve on this stage.
            is_not_u_before_solve && post_implicit!(U, p, t_imp)
        end

        if !isnothing(T_lim) || !isnothing(T_exp)
            if !isnothing(T_exp_T_lim!)
                T_exp_T_lim!(T_exp, T_lim, U, p, t_exp)
                if stage != s
                    # TODO: Fuse these two DSS calls into one.
                    dss!(T_lim, p, t_exp)
                    dss!(T_exp, p, t_exp)
                end
            end
            # TODO: Drop support for specifying T_lim! separately from T_exp!.
            if !isnothing(T_lim!)
                T_lim!(T_lim, U, p, t_exp)
                stage != s && dss!(T_lim, p, t_exp)
            end
            if !isnothing(T_exp!)
                T_exp!(T_exp, U, p, t_exp)
                stage != s && dss!(T_exp, p, t_exp)
            end
        end
        if !isnothing(T_imp)
            T_imp!(T_imp, U, p, t_imp)
        end
        if !isnothing(ΔU_imp)
            # U aliases ΔU_imp here, so this overwrites the stage state with the
            # implicit increment once the tendencies above have been evaluated.
            @. ΔU_imp = U - u_plus_ΔU_lim - Δt * ΔU_exp_over_Δt
            #         = U - U_before_solve + Δt * ΔU_imp_from_T_imp_over_Δt +
            #           ΔU_imp_from_prev_ΔU_imp
        end
    end

    t_final = t + Δt
    # NOTE(review): these read entry s + 1 of tuples that init_cache builds with
    # dense_tuple(..., s, ...); confirm that dense_tuple retains that entry.
    ΔtT_lim_to_final_Δu_lim = ΔtT_lim_to_Δu_lim_tuples[s + 1]
    ΔtT_exp_to_final_Δu_exp = ΔtT_exp_to_Δu_exp_tuples[s + 1]
    ΔtT_imp_to_final_Δu_imp = ΔtT_imp_to_Δu_imp_tuples[s + 1]
    ΔU_imp_to_final_Δu_imp = prev_ΔU_imp_to_Δu_imp_tuples[s + 1]
    final_Δu_lim_over_Δt = sparse_broadcasted_dot(ΔtT_lim_to_final_Δu_lim, T_lim_values_sparse)
    final_Δu_exp_over_Δt = sparse_broadcasted_dot(ΔtT_exp_to_final_Δu_exp, T_exp_values_sparse)
    final_Δu_imp_from_T_imp_over_Δt = sparse_broadcasted_dot(ΔtT_imp_to_final_Δu_imp, T_imp_values_sparse)
    final_Δu_imp_from_ΔU_imp = sparse_broadcasted_dot(ΔU_imp_to_final_Δu_imp, ΔU_imp_values_sparse)

    if is_accessible(ΔtT_lim_to_final_Δu_lim)
        # temp_value1 is free again once the stage loop has finished. (The cache
        # has no `temp_value` entry, so referring to that name would throw.)
        final_u_plus_Δu_lim = temp_value1
        @. final_u_plus_Δu_lim = u + Δt * final_Δu_lim_over_Δt
        lim!(final_u_plus_Δu_lim, p, t_final, u)
    else
        final_u_plus_Δu_lim = u
    end

    @. u =
        final_u_plus_Δu_lim + Δt * (final_Δu_exp_over_Δt + final_Δu_imp_from_T_imp_over_Δt) + final_Δu_imp_from_ΔU_imp
    dss!(u, p, t_final)
    post_implicit!(u, p, t_final)

    return u
end

src/solvers/imex_tableaus.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,14 @@ default values for `c_exp` and `c_imp` assume that it is internally consistent.
1919
The explicit tableau must be strictly lower triangular, and the implicit tableau
2020
must be lower triangular (only DIRK algorithms are currently supported).
2121
"""
22-
struct IMEXTableau{AE <: SPCO, BE <: SPCO, CE <: SPCO, AI <: SPCO, BI <: SPCO, CI <: SPCO}
23-
a_exp::AE # matrix of size s×s
24-
b_exp::BE # vector of length s
25-
c_exp::CE # vector of length s
26-
a_imp::AI # matrix of size s×s
27-
b_imp::BI # vector of length s
28-
c_imp::CI # vector of length s
29-
end
30-
IMEXTableau(args...) = IMEXTableau(map(x -> SparseCoeffs(x), args)...)
22+
struct IMEXTableau{MatrixType, VectorType}
    # Explicit part of the tableau.
    a_exp::MatrixType # s×s matrix
    b_exp::VectorType # length-s vector
    c_exp::VectorType # length-s vector
    # Implicit part of the tableau.
    a_imp::MatrixType # s×s matrix
    b_imp::VectorType # length-s vector
    c_imp::VectorType # length-s vector
end
3130

3231
function IMEXTableau(;
3332
a_exp,

0 commit comments

Comments
 (0)