Skip to content

Commit b890aaa

Browse files
committed
Reformulate IMEX ARK to correctly account for Newton residuals and DSS
1 parent f2e2b71 commit b890aaa

File tree

4 files changed

+371
-10
lines changed

4 files changed

+371
-10
lines changed

src/ClimaTimeSteppers.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ array_device(x) = CUDADevice() # assume CUDA
5959

6060
import DiffEqBase, SciMLBase, LinearAlgebra, DiffEqCallbacks, Krylov
6161

62+
include(joinpath("utilities", "sparse_tuple.jl"))
6263
include(joinpath("utilities", "sparse_coeffs.jl"))
6364
include(joinpath("utilities", "fused_increment.jl"))
6465
include("sparse_containers.jl")
@@ -118,7 +119,7 @@ const SPCO = SparseCoeffs
118119

119120
include("solvers/imex_tableaus.jl")
120121
include("solvers/explicit_tableaus.jl")
121-
include("solvers/imex_ark.jl")
122+
include("solvers/imex_ark_new.jl")
122123
include("solvers/imex_ssprk.jl")
123124
include("solvers/multirate.jl")
124125
include("solvers/lsrk.jl")

src/solvers/imex_ark_new.jl

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
"""
    has_jac(T_imp!)

Return `true` when `T_imp!` has both a `Wfact` field and a `jac_prototype`
field and neither of them is `nothing`; return `false` otherwise.
"""
function has_jac(T_imp!)
    F = typeof(T_imp!)
    # Both fields must exist before they can be inspected.
    (hasfield(F, :Wfact) && hasfield(F, :jac_prototype)) || return false
    return !(isnothing(T_imp!.Wfact) || isnothing(T_imp!.jac_prototype))
end
6+
7+
"""
    imp_error(name)

Throw an error stating that the tableau called `name` (or "The given
IMEXTableau" when `name` is `nothing`) has implicit stages that need a
nonlinear solver, so `NewtonsMethod` must be supplied together with `T_imp!`.
"""
function imp_error(name)
    subject = isnothing(name) ? "The given IMEXTableau" : name
    return error(
        string(
            subject,
            " has implicit stages that require a nonlinear solver, ",
            "so NewtonsMethod must be specified alongside T_imp!.",
        ),
    )
end
10+
11+
"""
    sdirk_error(name)

Throw an error stating that the tableau called `name` (or the given tableau
when `name` is `nothing`) is not SDIRK, so the Jacobian update must not be
tied to the `NewTimeStep` signal.
"""
function sdirk_error(name)
    unnamed = isnothing(name)
    subject = unnamed ? "The given IMEXTableau" : name
    referent = unnamed ? "this tableau" : name
    return error(
        string(
            subject,
            " has implicit stages with distinct coefficients (it ",
            "is not SDIRK), and an update is required whenever a ",
            "stage has a different coefficient from the previous ",
            "stage. Do not update on the NewTimeStep signal when ",
            "using ",
            referent,
            ".",
        ),
    )
end
17+
18+
"""
    IMEXARKCache(stage_cache, newtons_method_cache)

Container pairing the per-stage storage built by `init_cache` with the cache
of the Newton solver (`nothing` when no Newton's method is used).
"""
struct IMEXARKCache{StageCache, NewtonCache}
    # Named tuple of stage buffers and coefficient tuples read by `step_u!`.
    stage_cache::StageCache
    # Cache allocated for the Newton solver, or `nothing`.
    newtons_method_cache::NewtonCache
end
22+
23+
"""
    init_cache(prob, alg::IMEXAlgorithm{Unconstrained}; kwargs...)

Build the `IMEXARKCache` used by `step_u!` for the IMEX ARK algorithm `alg`
applied to `prob`.

The cache holds two scratch states, per-stage storage for the limited,
explicit and implicit tendencies and for the implicit increments `ΔU_imp`, and
the tableau coefficients rearranged into per-stage tuples (with the final
state treated as stage `s + 1`). When `alg.newtons_method` is not `nothing`, a
cache for it is allocated as well.

Throws (via `imp_error`) when the tableau has stages with a nonzero implicit
diagonal coefficient, `T_imp!` is given, and no Newton's method is specified.
"""
function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{Unconstrained}; kwargs...)
    (; u0) = prob
    (; T_lim!, T_exp!, T_exp_T_lim!, T_imp!) = prob.f
    (; name, newtons_method) = alg
    (; a_exp, b_exp, a_imp, b_imp) = alg.tableau

    no_T_lim = isnothing(T_lim!) && isnothing(T_exp_T_lim!)
    no_T_exp = isnothing(T_exp!) && isnothing(T_exp_T_lim!)
    no_T_imp = isnothing(T_imp!)

    # `length` does not accept a dimension argument; `size` does.
    s = size(a_imp, 1) # number of stages

    A_exp = vcat(a_exp, b_exp') # exp coefs with final state seen as stage s + 1
    A_imp = vcat(a_imp, b_imp') # imp coefs with final state seen as stage s + 1

    z_stages = findall(iszero, diag(a_imp)) # stages without implicit solves
    nz_stages = findall(!iszero, diag(a_imp)) # stages with implicit solves
    exp_stages = findall(col -> any(!iszero, col), eachcol(A_exp))
    # `findall` over the columns of `A_imp[:, z_stages]` yields positions within
    # `z_stages`; map them back to stage numbers so that the storage allocated
    # below is attached to the stages whose implicit tendency is actually used.
    imp_z_stages = z_stages[findall(col -> any(!iszero, col), eachcol(A_imp[:, z_stages]))]

    # `Γ` depends on `nz_stages`, so it must be computed after the stage sets.
    Γ = a_imp[nz_stages, nz_stages] # "fully implicit" part of a_imp
    sdirk_γ = length(unique(diag(Γ))) == 1 ? diag(Γ)[1] : nothing

    temp_value1 = similar(u0)
    temp_value2 = similar(u0)
    T_lim_by_stage = no_T_lim ? SparseTuple() : SparseTuple(map(_ -> similar(u0), exp_stages), exp_stages)
    T_exp_by_stage = no_T_exp ? SparseTuple() : SparseTuple(map(_ -> similar(u0), exp_stages), exp_stages)
    T_imp_by_stage = no_T_imp ? SparseTuple() : SparseTuple(map(_ -> similar(u0), imp_z_stages), imp_z_stages)
    ΔU_imp_by_stage = no_T_imp ? SparseTuple() : SparseTuple(map(_ -> similar(u0), nz_stages), nz_stages)

    isnothing(newtons_method) && !isempty(ΔU_imp_by_stage) && imp_error(name)

    prev_ΔtT_lim_to_Δu_coefs_by_stage = no_T_lim ? SparseTuple() : sparse_matrix_rows(A_exp, 1:(s + 1), 1:s)
    prev_ΔtT_exp_to_Δu_coefs_by_stage = no_T_exp ? SparseTuple() : sparse_matrix_rows(A_exp, 1:(s + 1), 1:s)
    prev_z_ΔtT_imp_to_Δu_coefs_by_stage =
        no_T_imp ? SparseTuple() : sparse_matrix_rows(A_imp[:, z_stages], 1:(s + 1), z_stages)

    # Coefficients of the implicit tendencies of the nz stages, with the
    # diagonal of Γ removed. The matrix has one column per entry of
    # `nz_stages`, so its columns are addressed with `:` rather than with the
    # stage numbers themselves (which would be out of bounds in general).
    prev_nz_ΔtT_imp_to_Δu_coef_matrix = A_imp[:, nz_stages]
    prev_nz_ΔtT_imp_to_Δu_coef_matrix[nz_stages, :] .= Γ - Diagonal(Γ)
    prev_nz_ΔU_imp_to_Δu_coefs_by_stage =
        no_T_imp ? SparseTuple() : sparse_matrix_rows(prev_nz_ΔtT_imp_to_Δu_coef_matrix * inv(Γ), 1:(s + 1), nz_stages)

    # Convert all inputs to unrolled_foreach in step_u! into regular tuples.
    # The coefficient tuples need an entry for the final state (row s + 1),
    # which step_u! reads after the stage loop. This assumes the second
    # argument of dense_tuple is the length of the output — TODO confirm.
    T_lim_by_stage_dense = dense_tuple(T_lim_by_stage, s, nothing)
    T_exp_by_stage_dense = dense_tuple(T_exp_by_stage, s, nothing)
    T_imp_by_stage_dense = dense_tuple(T_imp_by_stage, s, nothing)
    ΔU_imp_by_stage_dense = dense_tuple(ΔU_imp_by_stage, s, nothing)
    prev_ΔtT_lim_to_Δu_coefs_by_stage_dense = dense_tuple(prev_ΔtT_lim_to_Δu_coefs_by_stage, s + 1, SparseTuple())
    prev_ΔtT_exp_to_Δu_coefs_by_stage_dense = dense_tuple(prev_ΔtT_exp_to_Δu_coefs_by_stage, s + 1, SparseTuple())
    prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense =
        dense_tuple(prev_z_ΔtT_imp_to_Δu_coefs_by_stage, s + 1, SparseTuple())
    prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense =
        dense_tuple(prev_nz_ΔU_imp_to_Δu_coefs_by_stage, s + 1, SparseTuple())

    stage_cache = (;
        s,
        sdirk_γ,
        temp_value1,
        temp_value2,
        T_lim_by_stage,
        T_exp_by_stage,
        T_imp_by_stage,
        ΔU_imp_by_stage,
        T_lim_by_stage_dense,
        T_exp_by_stage_dense,
        T_imp_by_stage_dense,
        ΔU_imp_by_stage_dense,
        prev_ΔtT_lim_to_Δu_coefs_by_stage_dense,
        prev_ΔtT_exp_to_Δu_coefs_by_stage_dense,
        prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense,
        prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense,
    )

    # The Jacobian prototype is only passed along when T_imp! provides both
    # a Wfact and a jac_prototype.
    newtons_method_cache =
        isnothing(newtons_method) ? nothing :
        allocate_cache(newtons_method, u0, has_jac(T_imp!) ? T_imp!.jac_prototype : nothing)

    return IMEXARKCache(stage_cache, newtons_method_cache)
end
99+
100+
"""
    step_u!(integrator, cache::IMEXARKCache)

Advance `integrator.u` in place by one time step of size `integrator.dt`
using the storage and coefficients in `cache`, and return `u`.

For each stage, the state before the implicit solve is assembled from `u`,
the limited and explicit tendencies of earlier stages, and the implicit
contributions of earlier stages. `lim!` is applied to the limited part and
`post_explicit!` to the assembled state. Stages that have `ΔU_imp` storage are
then solved with `solve_newton!`. Afterwards the tendencies needed by later
stages are evaluated (with `dss!` applied to the limited and explicit ones)
and stored. The final state is assembled in the same way from the last row of
coefficients, followed by `dss!` and `post_explicit!`.

Throws (via `sdirk_error`) when the Jacobian is to be updated on the
`NewTimeStep` signal but the implicit diagonal coefficients are not all equal.
"""
function step_u!(integrator, cache::IMEXARKCache)
    (; u, p, t, alg) = integrator
    (; T_lim!, T_exp!, T_exp_T_lim!, T_imp!) = integrator.sol.prob.f
    (; lim!, dss!, post_explicit!, post_implicit!) = integrator.sol.prob.f
    (; name, newtons_method) = alg
    (; a_imp, c_exp, c_imp) = alg.tableau
    (; newtons_method_cache) = cache
    (;
        s,
        sdirk_γ,
        temp_value1,
        temp_value2,
        T_lim_by_stage,
        T_exp_by_stage,
        T_imp_by_stage,
        ΔU_imp_by_stage,
        T_lim_by_stage_dense,
        T_exp_by_stage_dense,
        T_imp_by_stage_dense,
        ΔU_imp_by_stage_dense,
        prev_ΔtT_lim_to_Δu_coefs_by_stage_dense,
        prev_ΔtT_exp_to_Δu_coefs_by_stage_dense,
        prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense,
        prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense,
    ) = cache.stage_cache

    Δt = integrator.dt

    # Update the Jacobian once per step if requested; this is only possible
    # when every implicit stage shares the same diagonal coefficient.
    if !isnothing(T_imp!) && !isnothing(newtons_method)
        (; update_j) = newtons_method
        (; j) = newtons_method_cache
        if !isnothing(j) && needs_update!(update_j, NewTimeStep(t))
            isnothing(sdirk_γ) && sdirk_error(name)
            T_imp!.Wfact(j, u, p, Δt * sdirk_γ, t)
        end
    end

    unrolled_foreach(
        1:s,
        T_lim_by_stage_dense[1:s],
        T_exp_by_stage_dense[1:s],
        T_imp_by_stage_dense[1:s],
        ΔU_imp_by_stage_dense[1:s],
        prev_ΔtT_lim_to_Δu_coefs_by_stage_dense[1:s],
        prev_ΔtT_exp_to_Δu_coefs_by_stage_dense[1:s],
        prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense[1:s],
        prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense[1:s],
    ) do (
        stage,
        T_lim,
        T_exp,
        T_imp,
        ΔU_imp,
        prev_ΔtT_lim_to_Δu_coefs,
        prev_ΔtT_exp_to_Δu_coefs,
        prev_z_ΔtT_imp_to_Δu_coefs,
        prev_nz_ΔU_imp_to_Δu_coefs,
    )
        t_exp = t + Δt * c_exp[stage]
        t_imp = t + Δt * c_imp[stage]
        Δtγ = Δt * a_imp[stage, stage]

        Δu_over_Δt_from_prev_T_lim = broadcasted_dot(prev_ΔtT_lim_to_Δu_coefs, T_lim_by_stage)
        Δu_over_Δt_from_prev_T_exp = broadcasted_dot(prev_ΔtT_exp_to_Δu_coefs, T_exp_by_stage)
        Δu_over_Δt_from_prev_T_imp = broadcasted_dot(prev_z_ΔtT_imp_to_Δu_coefs, T_imp_by_stage)
        Δu_from_prev_ΔU_imp = broadcasted_dot(prev_nz_ΔU_imp_to_Δu_coefs, ΔU_imp_by_stage)

        # Limited contribution: skipped entirely when this stage has no
        # limited coefficients, in which case u itself is used.
        if isempty(prev_ΔtT_lim_to_Δu_coefs)
            u_plus_Δu_lim = u
        else
            u_plus_Δu_lim = temp_value1
            @. u_plus_Δu_lim = u + Δt * Δu_over_Δt_from_prev_T_lim
            lim!(u_plus_Δu_lim, p, t_exp, u)
        end

        if (
            isempty(prev_ΔtT_exp_to_Δu_coefs) &&
            isempty(prev_z_ΔtT_imp_to_Δu_coefs) &&
            isempty(prev_nz_ΔU_imp_to_Δu_coefs)
        )
            U_before_solve = u_plus_Δu_lim
        else
            U_before_solve = temp_value2
            @. U_before_solve =
                u_plus_Δu_lim + Δt * (Δu_over_Δt_from_prev_T_exp + Δu_over_Δt_from_prev_T_imp) + Δu_from_prev_ΔU_imp
        end

        post_explicit!(U_before_solve, p, t_imp)

        if !isnothing(ΔU_imp)
            U = ΔU_imp # Use ΔU_imp as additional temporary storage.
            @. U = U_before_solve

            # Solve U ≈ U_before_solve + Δtγ * T_imp(U, p, t_imp) for U.
            set_residual! = (residual, U_guess) -> begin
                T_imp!(residual, U_guess, p, t_imp)
                @. residual = U_before_solve - U_guess + Δtγ * residual
            end
            set_jacobian! = (jacobian, U_guess) -> T_imp!.Wfact(jacobian, U_guess, p, Δtγ, t_imp)
            # Use a distinct name for the wrapper: assigning to
            # `post_implicit!` here would rebind the captured outer variable
            # to a closure that calls itself, recursing without end.
            call_post_implicit! = U_guess -> post_implicit!(U_guess, p, t_imp)
            solve_newton!(
                newtons_method,
                newtons_method_cache,
                U,
                set_residual!,
                set_jacobian!,
                call_post_implicit!,
                call_post_implicit!,
            )
        else
            U = U_before_solve # There is no solve on this stage.
        end

        if !isnothing(T_lim) || !isnothing(T_exp)
            if !isnothing(T_exp_T_lim!)
                T_exp_T_lim!(T_exp, T_lim, U, p, t_exp)
                dss!(T_lim, p, t_exp)
                dss!(T_exp, p, t_exp)
            end
            if !isnothing(T_lim!)
                T_lim!(T_lim, U, p, t_exp)
                dss!(T_lim, p, t_exp)
            end
            if !isnothing(T_exp!)
                T_exp!(T_exp, U, p, t_exp)
                dss!(T_exp, p, t_exp)
            end
            # TODO: Can we just use T_exp_T_lim!, and not 3 different functions?
            # TODO: Fuse the DSS calls above into a single operation.
        end

        if !isnothing(ΔU_imp)
            # U aliases ΔU_imp here, so this overwrites the solved state with
            # the increment. Only the limited and explicit contributions are
            # subtracted; Δt * Δu_over_Δt_from_prev_T_imp is not.
            # NOTE(review): confirm that omission is intended.
            @. ΔU_imp = U - u_plus_Δu_lim - Δt * Δu_over_Δt_from_prev_T_exp
        elseif !isnothing(T_imp)
            T_imp!(T_imp, U, p, t_imp)
        end
    end

    # Assemble the final state from the coefficients in row s + 1.
    t_final = t + Δt
    prev_ΔtT_lim_to_Δu_coefs_final = prev_ΔtT_lim_to_Δu_coefs_by_stage_dense[s + 1]
    prev_ΔtT_exp_to_Δu_coefs_final = prev_ΔtT_exp_to_Δu_coefs_by_stage_dense[s + 1]
    prev_z_ΔtT_imp_to_Δu_coefs_final = prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense[s + 1]
    prev_nz_ΔU_imp_to_Δu_coefs_final = prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense[s + 1]
    Δu_over_Δt_from_prev_T_lim_final = broadcasted_dot(prev_ΔtT_lim_to_Δu_coefs_final, T_lim_by_stage)
    Δu_over_Δt_from_prev_T_exp_final = broadcasted_dot(prev_ΔtT_exp_to_Δu_coefs_final, T_exp_by_stage)
    Δu_over_Δt_from_prev_T_imp_final = broadcasted_dot(prev_z_ΔtT_imp_to_Δu_coefs_final, T_imp_by_stage)
    Δu_from_prev_ΔU_imp_final = broadcasted_dot(prev_nz_ΔU_imp_to_Δu_coefs_final, ΔU_imp_by_stage)

    # A separate name keeps this binding from being shared with the stage
    # closure above.
    if isempty(prev_ΔtT_lim_to_Δu_coefs_final)
        u_plus_Δu_lim_final = u
    else
        u_plus_Δu_lim_final = temp_value1
        @. u_plus_Δu_lim_final = u + Δt * Δu_over_Δt_from_prev_T_lim_final
        lim!(u_plus_Δu_lim_final, p, t_final, u)
    end

    @. u =
        u_plus_Δu_lim_final +
        Δt * (Δu_over_Δt_from_prev_T_exp_final + Δu_over_Δt_from_prev_T_imp_final) +
        Δu_from_prev_ΔU_imp_final
    dss!(u, p, t_final)
    post_explicit!(u, p, t_final)

    return u
end

src/solvers/imex_tableaus.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,14 @@ default values for `c_exp` and `c_imp` assume that it is internally consistent.
1919
The explicit tableau must be strictly lower triangular, and the implicit tableau
2020
must be lower triangular (only DIRK algorithms are currently supported).
2121
"""
22-
struct IMEXTableau{AE <: SPCO, BE <: SPCO, CE <: SPCO, AI <: SPCO, BI <: SPCO, CI <: SPCO}
23-
a_exp::AE # matrix of size s×s
24-
b_exp::BE # vector of length s
25-
c_exp::CE # vector of length s
26-
a_imp::AI # matrix of size s×s
27-
b_imp::BI # vector of length s
28-
c_imp::CI # vector of length s
29-
end
30-
IMEXTableau(args...) = IMEXTableau(map(x -> SparseCoeffs(x), args)...)
22+
struct IMEXTableau{Mat, Vec}
    # Explicit part of the tableau.
    a_exp::Mat # s×s matrix
    b_exp::Vec # length-s vector
    c_exp::Vec # length-s vector
    # Implicit part of the tableau.
    a_imp::Mat # s×s matrix
    b_imp::Vec # length-s vector
    c_imp::Vec # length-s vector
end
3130

3231
function IMEXTableau(;
3332
a_exp,

0 commit comments

Comments
 (0)