Skip to content

Commit c1a1157

Browse files
Support asynchronous T_lim and T_exp
1 parent 6839f2e commit c1a1157

File tree

3 files changed

+96
-56
lines changed

3 files changed

+96
-56
lines changed

src/functions.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@ export ClimaODEFunction, ForwardEulerODEFunction
44

55
abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end
66

7-
Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction
8-
T_lim!::TL = nothing # nothing or (uₜ, u, p, t) -> ...
9-
T_exp!::TE = nothing # nothing or (uₜ, u, p, t) -> ...
7+
Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI, CC} <: AbstractClimaODEFunction
8+
T_lim!::TL = (uₜ, u, p, t) -> nothing
9+
T_exp!::TE = (uₜ, u, p, t) -> nothing
1010
T_imp!::TI = nothing # nothing or (uₜ, u, p, t) -> ...
1111
lim!::L = (u, p, t, u_ref) -> nothing
1212
dss!::D = (u, p, t) -> nothing
1313
post_explicit!::PE = (u, p, t) -> nothing
1414
post_implicit!::PI = (u, p, t) -> nothing
15+
comms_context::CC = nothing
1516
end
1617

1718
# Don't wrap an AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work).

src/solvers/imex_ark.jl

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ function step_u!(integrator, cache::IMEXARKCache)
5050
(; u, p, t, dt, alg) = integrator
5151
(; f) = integrator.sol.prob
5252
(; post_explicit!, post_implicit!) = f
53+
(; comms_context) = f
5354
(; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
5455
(; tableau, newtons_method) = alg
5556
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
@@ -74,19 +75,17 @@ function step_u!(integrator, cache::IMEXARKCache)
7475

7576
@. U = u
7677

77-
if !isnothing(T_lim!) # Update based on limited tendencies from previous stages
78-
for j in 1:(i - 1)
79-
iszero(a_exp[i, j]) && continue
80-
@. U += dt * a_exp[i, j] * T_lim[j]
81-
end
82-
lim!(U, p, t_exp, u)
78+
# Update based on limited tendencies from previous stages
79+
for j in 1:(i - 1)
80+
iszero(a_exp[i, j]) && continue
81+
@. U += dt * a_exp[i, j] * T_lim[j]
8382
end
83+
lim!(U, p, t_exp, u)
8484

85-
if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages
86-
for j in 1:(i - 1)
87-
iszero(a_exp[i, j]) && continue
88-
@. U += dt * a_exp[i, j] * T_exp[j]
89-
end
85+
# Update based on explicit tendencies from previous stages
86+
for j in 1:(i - 1)
87+
iszero(a_exp[i, j]) && continue
88+
@. U += dt * a_exp[i, j] * T_exp[j]
9089
end
9190

9291
if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages
@@ -147,32 +146,54 @@ function step_u!(integrator, cache::IMEXARKCache)
147146
end
148147

149148
if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i])
150-
if !isnothing(T_lim!)
149+
if isnothing(comms_context)
151150
T_lim!(T_lim[i], U, p, t_exp)
152-
end
153-
if !isnothing(T_exp!)
154151
T_exp!(T_exp[i], U, p, t_exp)
152+
else # do asynchronously
153+
154+
# https://github.com/JuliaLang/julia/issues/40626
155+
if ClimaComms.device(comms_context) isa CUDA.CUDADevice
156+
CUDA.@sync begin
157+
@async begin
158+
T_lim!(T_lim[i], U, p, t_exp)
159+
nothing
160+
end
161+
@async begin
162+
T_exp!(T_exp[i], U, p, t_exp)
163+
nothing
164+
end
165+
end
166+
else
167+
@sync begin
168+
@async begin
169+
T_lim!(T_lim[i], U, p, t_exp)
170+
nothing
171+
end
172+
@async begin
173+
T_exp!(T_exp[i], U, p, t_exp)
174+
nothing
175+
end
176+
end
177+
end
155178
end
156179
end
157180
end
158181

159182
t_final = t + dt
160183

161-
if !isnothing(T_lim!) # Update based on limited tendencies from previous stages
162-
@. temp = u
163-
for j in 1:s
164-
iszero(b_exp[j]) && continue
165-
@. temp += dt * b_exp[j] * T_lim[j]
166-
end
167-
lim!(temp, p, t_final, u)
168-
@. u = temp
184+
# Update based on limited tendencies from previous stages
185+
@. temp = u
186+
for j in 1:s
187+
iszero(b_exp[j]) && continue
188+
@. temp += dt * b_exp[j] * T_lim[j]
169189
end
190+
lim!(temp, p, t_final, u)
191+
@. u = temp
170192

171-
if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages
172-
for j in 1:s
173-
iszero(b_exp[j]) && continue
174-
@. u += dt * b_exp[j] * T_exp[j]
175-
end
193+
# Update based on explicit tendencies from previous stages
194+
for j in 1:s
195+
iszero(b_exp[j]) && continue
196+
@. u += dt * b_exp[j] * T_exp[j]
176197
end
177198

178199
if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages

src/solvers/imex_ssprk.jl

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{SSP}
1919
s = length(b_exp)
2020
inds = ntuple(i -> i, s)
2121
inds_T_imp = filter(i -> !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]), inds)
22-
U = similar(u0)
23-
U_exp = similar(u0)
24-
T_lim = similar(u0)
25-
T_exp = similar(u0)
26-
U_lim = similar(u0)
27-
T_imp = SparseContainer(map(i -> similar(u0), collect(1:length(inds_T_imp))), inds_T_imp)
28-
temp = similar(u0)
22+
U = zero(u0)
23+
U_exp = zero(u0)
24+
T_lim = zero(u0)
25+
T_exp = zero(u0)
26+
U_lim = zero(u0)
27+
T_imp = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_imp))), inds_T_imp)
28+
temp = zero(u0)
2929
â_exp = vcat(a_exp, b_exp')
3030
β = diag(â_exp, -1)
3131
for i in 1:length(β)
@@ -56,6 +56,7 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
5656
(; u, p, t, dt, alg) = integrator
5757
(; f) = integrator.sol.prob
5858
(; post_explicit!, post_implicit!) = f
59+
(; comms_context) = f
5960
(; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
6061
(; tableau, newtons_method) = alg
6162
(; a_imp, b_imp, c_exp, c_imp) = tableau
@@ -83,14 +84,10 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
8384
if i == 1
8485
@. U_exp = u
8586
elseif !iszero(β[i - 1])
86-
if !isnothing(T_lim!)
87-
@. U_lim = U_exp + dt * T_lim
88-
lim!(U_lim, p, t_exp, U_exp)
89-
@. U_exp = U_lim
90-
end
91-
if !isnothing(T_exp!)
92-
@. U_exp += dt * T_exp
93-
end
87+
@. U_lim = U_exp + dt * T_lim
88+
lim!(U_lim, p, t_exp, U_exp)
89+
@. U_exp = U_lim
90+
@. U_exp += dt * T_exp
9491
@. U_exp = (1 - β[i - 1]) * u + β[i - 1] * U_exp
9592
end
9693

@@ -153,26 +150,47 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
153150
end
154151

155152
if !iszero(β[i])
156-
if !isnothing(T_lim!)
153+
if isnothing(comms_context)
157154
T_lim!(T_lim, U, p, t_exp)
158-
end
159-
if !isnothing(T_exp!)
160155
T_exp!(T_exp, U, p, t_exp)
156+
else
157+
158+
# https://github.com/JuliaLang/julia/issues/40626
159+
if ClimaComms.device(comms_context) isa CUDA.CUDADevice
160+
CUDA.@sync begin
161+
@async begin
162+
T_lim!(T_lim, U, p, t_exp)
163+
nothing
164+
end
165+
@async begin
166+
T_exp!(T_exp, U, p, t_exp)
167+
nothing
168+
end
169+
end
170+
else
171+
@sync begin
172+
@async begin
173+
T_lim!(T_lim, U, p, t_exp)
174+
nothing
175+
end
176+
@async begin
177+
T_exp!(T_exp, U, p, t_exp)
178+
nothing
179+
end
180+
end
181+
end
182+
161183
end
162184
end
163185
end
164186

165187
t_final = t + dt
166188

167189
if !iszero(β[s])
168-
if !isnothing(T_lim!)
169-
@. U_lim = U_exp + dt * T_lim
170-
lim!(U_lim, p, t_final, U_exp)
171-
@. U_exp = U_lim
172-
end
173-
if !isnothing(T_exp!)
174-
@. U_exp += dt * T_exp
175-
end
190+
@. U_lim = U_exp + dt * T_lim
191+
lim!(U_lim, p, t_final, U_exp)
192+
@. U_exp = U_lim
193+
@. U_exp += dt * T_exp
176194
@. u = (1 - β[s]) * u + β[s] * U_exp
177195
end
178196

0 commit comments

Comments
 (0)