Skip to content

Commit 96b7067

Browse files
committed
Add DSS before implicit solve, fix T_imp approximation and callbacks
1 parent 5f1503b commit 96b7067

File tree

5 files changed

+63
-99
lines changed

5 files changed

+63
-99
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ClimaTimeSteppers"
22
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
33
authors = ["Climate Modeling Alliance"]
4-
version = "0.7.39"
4+
version = "0.7.40"
55

66
[deps]
77
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"

src/solvers/hard_coded_ars343.jl

Lines changed: 23 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
2222

2323
i::Int = 1
2424
t_exp = t
25-
@. U = u
26-
lim!(U, p, t_exp, u)
27-
dss!(U, p, t_exp)
25+
@. U = u # TODO: This is unnecessary; we can just pass u to T_exp and T_lim
2826
T_lim!(T_lim[i], U, p, t_exp)
2927
T_exp!(T_exp[i], U, p, t_exp)
3028

@@ -33,38 +31,32 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
3331
@. U = u + dt * a_exp[i, 1] * T_lim[1]
3432
lim!(U, p, t_exp, u)
3533
@. U += dt * a_exp[i, 1] * T_exp[1]
36-
post_explicit!(U, p, t_exp)
37-
34+
dss!(U, p, t_exp)
3835
@. temp = U # used in closures
3936
let i = i
4037
t_imp = t + dt * c_imp[i]
38+
post_implicit!(U, p, t_imp)
4139
implicit_equation_residual! = (residual, Ui) -> begin
4240
T_imp!(residual, Ui, p, t_imp)
4341
@. residual = temp + dt * a_imp[i, i] * residual - Ui
4442
end
4543
implicit_equation_jacobian! = (jacobian, Ui) -> begin
4644
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
4745
end
48-
call_post_implicit! = Ui -> begin
49-
post_implicit!(Ui, p, t_imp)
50-
end
51-
call_post_implicit_last! = Ui -> begin
52-
dss!(Ui, p, t_imp)
53-
post_implicit!(Ui, p, t_imp)
54-
end
46+
call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp)
5547
solve_newton!(
5648
newtons_method,
5749
newtons_method_cache,
5850
U,
5951
implicit_equation_residual!,
6052
implicit_equation_jacobian!,
6153
call_post_implicit!,
62-
call_post_implicit_last!,
54+
nothing,
6355
)
56+
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
57+
dss!(U, p, t_imp)
58+
post_explicit!(U, p, t_imp)
6459
end
65-
66-
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
67-
6860
T_lim!(T_lim[i], U, p, t_exp)
6961
T_exp!(T_exp[i], U, p, t_exp)
7062

@@ -73,40 +65,35 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
7365
@. U = u + dt * a_exp[i, 1] * T_lim[1] + dt * a_exp[i, 2] * T_lim[2]
7466
lim!(U, p, t_exp, u)
7567
@. U += dt * a_exp[i, 1] * T_exp[1] + dt * a_exp[i, 2] * T_exp[2] + dt * a_imp[i, 2] * T_imp[2]
76-
post_explicit!(U, p, t_exp)
77-
68+
dss!(U, p, t_exp)
7869
@. temp = U # used in closures
7970
let i = i
8071
t_imp = t + dt * c_imp[i]
72+
post_implicit!(U, p, t_imp)
8173
implicit_equation_residual! = (residual, Ui) -> begin
8274
T_imp!(residual, Ui, p, t_imp)
8375
@. residual = temp + dt * a_imp[i, i] * residual - Ui
8476
end
8577
implicit_equation_jacobian! = (jacobian, Ui) -> begin
8678
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
8779
end
88-
call_post_implicit! = Ui -> begin
89-
post_implicit!(Ui, p, t_imp)
90-
end
91-
call_post_implicit_last! = Ui -> begin
92-
dss!(Ui, p, t_imp)
93-
post_implicit!(Ui, p, t_imp)
94-
end
80+
call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp)
9581
solve_newton!(
9682
newtons_method,
9783
newtons_method_cache,
9884
U,
9985
implicit_equation_residual!,
10086
implicit_equation_jacobian!,
10187
call_post_implicit!,
102-
call_post_implicit_last!,
88+
nothing,
10389
)
90+
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
91+
dss!(U, p, t_imp)
92+
post_explicit!(U, p, t_imp)
10493
end
105-
106-
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
107-
10894
T_lim!(T_lim[i], U, p, t_exp)
10995
T_exp!(T_exp[i], U, p, t_exp)
96+
11097
i = 4
11198
t_exp = t + dt
11299
@. U = u + dt * a_exp[i, 1] * T_lim[1] + dt * a_exp[i, 2] * T_lim[2] + dt * a_exp[i, 3] * T_lim[3]
@@ -117,44 +104,35 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
117104
dt * a_exp[i, 3] * T_exp[3] +
118105
dt * a_imp[i, 2] * T_imp[2] +
119106
dt * a_imp[i, 3] * T_imp[3]
120-
post_explicit!(U, p, t_exp)
121-
107+
dss!(U, p, t_exp)
122108
@. temp = U # used in closures
123109
let i = i
124110
t_imp = t + dt * c_imp[i]
111+
post_implicit!(U, p, t_imp)
125112
implicit_equation_residual! = (residual, Ui) -> begin
126113
T_imp!(residual, Ui, p, t_imp)
127114
@. residual = temp + dt * a_imp[i, i] * residual - Ui
128115
end
129116
implicit_equation_jacobian! = (jacobian, Ui) -> begin
130117
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
131118
end
132-
call_post_implicit! = Ui -> begin
133-
post_implicit!(Ui, p, t_imp)
134-
end
135-
call_post_implicit_last! = Ui -> begin
136-
dss!(Ui, p, t_imp)
137-
post_implicit!(Ui, p, t_imp)
138-
end
119+
call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp)
139120
solve_newton!(
140121
newtons_method,
141122
newtons_method_cache,
142123
U,
143124
implicit_equation_residual!,
144125
implicit_equation_jacobian!,
145126
call_post_implicit!,
146-
call_post_implicit_last!,
127+
nothing,
147128
)
129+
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
130+
dss!(U, p, t_imp)
131+
post_explicit!(U, p, t_imp)
148132
end
149-
150-
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
151-
152133
T_lim!(T_lim[i], U, p, t_exp)
153134
T_exp!(T_exp[i], U, p, t_exp)
154135

155-
# final
156-
i = -1
157-
158136
t_final = t + dt
159137
@. temp = u + dt * b_exp[2] * T_lim[2] + dt * b_exp[3] * T_lim[3] + dt * b_exp[4] * T_lim[4]
160138
lim!(temp, p, t_final, u)

src/solvers/imex_ark.jl

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -119,50 +119,43 @@ end
119119
has_T_exp(f) && fused_increment!(U, dt, a_exp, T_exp, Val(i))
120120
isnothing(T_imp!) || fused_increment!(U, dt, a_imp, T_imp, Val(i))
121121

122+
i 1 && dss!(U, p, t_exp)
123+
122124
if isnothing(T_imp!) || iszero(a_imp[i, i])
123-
i 1 && dss!(U, p, t_imp)
124-
i 1 && post_explicit!(U, p, t_imp)
125+
i 1 && post_explicit!(U, p, t_exp)
126+
if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i])
127+
# If its coefficient is 0, T_imp[i] is being treated explicitly.
128+
isnothing(T_imp!) || T_imp!(T_imp[i], U, p, t_imp)
129+
end
125130
else # Implicit solve
126131
@assert !isnothing(newtons_method)
132+
i 1 && post_implicit!(U, p, t_imp)
127133
@. temp = U
128-
# We do not need to apply DSS yet because the implicit solve does not
129-
# involve any horizontal derivatives.
130-
i 1 && post_explicit!(U, p, t_imp)
131134
# TODO: can/should we remove these closures?
132135
implicit_equation_residual! = (residual, Ui) -> begin
133136
T_imp!(residual, Ui, p, t_imp)
134137
@. residual = temp + dt * a_imp[i, i] * residual - Ui
135138
end
136-
implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
137-
call_post_implicit! = Ui -> begin
138-
post_implicit!(Ui, p, t_imp)
139+
implicit_equation_jacobian! = (jacobian, Ui) -> begin
140+
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
139141
end
140-
call_post_implicit_last! = Ui -> begin
141-
dss!(Ui, p, t_imp)
142-
post_implicit!(Ui, p, t_imp)
143-
end
144-
142+
call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp)
145143
solve_newton!(
146144
newtons_method,
147145
newtons_method_cache,
148146
U,
149147
implicit_equation_residual!,
150148
implicit_equation_jacobian!,
151149
call_post_implicit!,
152-
call_post_implicit_last!,
150+
nothing,
153151
)
154-
end
155-
156-
if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i])
157-
if iszero(a_imp[i, i])
158-
# If its coefficient is 0, T_imp[i] is effectively being
159-
# treated explicitly.
160-
isnothing(T_imp!) || T_imp!(T_imp[i], U, p, t_imp)
161-
else
152+
if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i])
162153
# If T_imp[i] is being treated implicitly, ensure that it
163-
# exactly satisfies the implicit equation.
164-
isnothing(T_imp!) || @. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
154+
# exactly satisfies the implicit equation before applying DSS.
155+
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
165156
end
157+
dss!(U, p, t_imp)
158+
post_explicit!(U, p, t_imp)
166159
end
167160

168161
if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i])

src/solvers/imex_ssprk.jl

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -102,50 +102,43 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
102102
end
103103
end
104104

105+
i 1 && dss!(U, p, t_exp)
106+
105107
if isnothing(T_imp!) || iszero(a_imp[i, i])
106-
i 1 && dss!(U, p, t_imp)
107-
i 1 && post_explicit!(U, p, t_imp)
108+
i 1 && post_explicit!(U, p, t_exp)
109+
if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i])
110+
# If its coefficient is 0, T_imp[i] is being treated explicitly.
111+
isnothing(T_imp!) || T_imp!(T_imp[i], U, p, t_imp)
112+
end
108113
else # Implicit solve
109114
@assert !isnothing(newtons_method)
115+
i 1 && post_implicit!(U, p, t_imp)
110116
@. temp = U
111-
# We do not need to apply DSS yet because the implicit solve does
112-
# not involve any horizontal derivatives.
113-
post_explicit!(U, p, t_imp)
114117
# TODO: can/should we remove these closures?
115118
implicit_equation_residual! = (residual, Ui) -> begin
116119
T_imp!(residual, Ui, p, t_imp)
117120
@. residual = temp + dt * a_imp[i, i] * residual - Ui
118121
end
119-
implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
120-
call_post_implicit! = Ui -> begin
121-
post_implicit!(Ui, p, t_imp)
122-
end
123-
call_post_implicit_last! = Ui -> begin
124-
dss!(Ui, p, t_imp)
125-
post_implicit!(Ui, p, t_imp)
122+
implicit_equation_jacobian! = (jacobian, Ui) -> begin
123+
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
126124
end
127-
125+
call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp)
128126
solve_newton!(
129127
newtons_method,
130128
newtons_method_cache,
131129
U,
132130
implicit_equation_residual!,
133131
implicit_equation_jacobian!,
134132
call_post_implicit!,
135-
call_post_implicit_last!,
133+
nothing,
136134
)
137-
end
138-
139-
if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i])
140-
if iszero(a_imp[i, i])
141-
# If its coefficient is 0, T_imp[i] is effectively being
142-
# treated explicitly.
143-
isnothing(T_imp!) || T_imp!(T_imp[i], U, p, t_imp)
144-
else
135+
if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i])
145136
# If T_imp[i] is being treated implicitly, ensure that it
146-
# exactly satisfies the implicit equation.
147-
isnothing(T_imp!) || @. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
137+
# exactly satisfies the implicit equation before applying DSS.
138+
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
148139
end
140+
dss!(U, p, t_imp)
141+
post_explicit!(U, p, t_imp)
149142
end
150143

151144
if !iszero(β[i])

src/solvers/rosenbrock.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages}
123123
T_exp_lim! = int.sol.prob.f.T_exp_T_lim!
124124
tgrad! = isnothing(T_imp!) ? nothing : T_imp!.tgrad
125125

126-
(; post_explicit!, post_implicit!, dss!) = int.sol.prob.f
126+
(; post_explicit!, dss!) = int.sol.prob.f
127127

128128
# TODO: This is only valid when Γ[i, i] is constant, otherwise we have to
129129
# move this in the for loop
@@ -150,15 +150,15 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages}
150150
U .+= A[i, j] .* k[j]
151151
end
152152

153-
# NOTE: post_implicit! is a misnomer
154-
if !isnothing(post_implicit!)
153+
# NOTE: post_explicit! is a misnomer; should be post_stage!
154+
if !isnothing(post_explicit!)
155155
# We apply DSS and update p on every stage but the first, and at the
156156
# end of each timestep. Since the first stage is unchanged from the
157157
# end of the previous timestep, this order of operations ensures
158158
# that the state is always continuous and that p is consistent with
159159
# the state, including between timesteps.
160160
(i != 1) && dss!(U, p, t + αi * dt)
161-
(i != 1) && post_implicit!(U, p, t + αi * dt)
161+
(i != 1) && post_explicit!(U, p, t + αi * dt)
162162
end
163163

164164
if !isnothing(T_imp!)
@@ -203,7 +203,7 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages}
203203
end
204204

205205
dss!(u, p, t + dt)
206-
post_implicit!(u, p, t + dt)
206+
post_explicit!(u, p, t + dt)
207207
return nothing
208208
end
209209

0 commit comments

Comments
 (0)