Skip to content

Commit cbbb0db

Browse files
Support asynchronous T_lim and T_exp
Bump patch version
1 parent 6839f2e commit cbbb0db

File tree

3 files changed

+34
-11
lines changed

3 files changed

+34
-11
lines changed

src/functions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@ export ClimaODEFunction, ForwardEulerODEFunction
44

55
abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end
66

7-
Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction
7+
Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI, CC} <: AbstractClimaODEFunction
88
T_lim!::TL = nothing # nothing or (uₜ, u, p, t) -> ...
99
T_exp!::TE = nothing # nothing or (uₜ, u, p, t) -> ...
1010
T_imp!::TI = nothing # nothing or (uₜ, u, p, t) -> ...
1111
lim!::L = (u, p, t, u_ref) -> nothing
1212
dss!::D = (u, p, t) -> nothing
1313
post_explicit!::PE = (u, p, t) -> nothing
1414
post_implicit!::PI = (u, p, t) -> nothing
15+
comms_context::CC = nothing
1516
end
1617

1718
# Don't wrap a AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work).

src/solvers/imex_ark.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ function step_u!(integrator, cache::IMEXARKCache)
5050
(; u, p, t, dt, alg) = integrator
5151
(; f) = integrator.sol.prob
5252
(; post_explicit!, post_implicit!) = f
53+
(; comms_context) = f
5354
(; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
5455
(; tableau, newtons_method) = alg
5556
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
@@ -147,11 +148,21 @@ function step_u!(integrator, cache::IMEXARKCache)
147148
end
148149

149150
if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i])
150-
if !isnothing(T_lim!)
151-
T_lim!(T_lim[i], U, p, t_exp)
152-
end
153-
if !isnothing(T_exp!)
154-
T_exp!(T_exp[i], U, p, t_exp)
151+
if !isnothing(T_lim!) && !isnothing(T_lim!) && !isnothing(comms_context)
152+
dev = ClimaComms.device(comms_context)
153+
ClimaComms.@sync dev begin
154+
@async begin
155+
T_lim!(T_lim[i], U, p, t_exp)
156+
nothing
157+
end
158+
@async begin
159+
T_exp!(T_exp[i], U, p, t_exp)
160+
nothing
161+
end
162+
end
163+
else
164+
isnothing(T_lim!) || T_lim!(T_lim[i], U, p, t_exp)
165+
isnothing(T_exp!) || T_exp!(T_exp[i], U, p, t_exp)
155166
end
156167
end
157168
end

src/solvers/imex_ssprk.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
5656
(; u, p, t, dt, alg) = integrator
5757
(; f) = integrator.sol.prob
5858
(; post_explicit!, post_implicit!) = f
59+
(; comms_context) = f
5960
(; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
6061
(; tableau, newtons_method) = alg
6162
(; a_imp, b_imp, c_exp, c_imp) = tableau
@@ -153,11 +154,21 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
153154
end
154155

155156
if !iszero(β[i])
156-
if !isnothing(T_lim!)
157-
T_lim!(T_lim, U, p, t_exp)
158-
end
159-
if !isnothing(T_exp!)
160-
T_exp!(T_exp, U, p, t_exp)
157+
if !isnothing(T_lim!) && !isnothing(T_lim!) && !isnothing(comms_context)
158+
dev = ClimaComms.device(comms_context)
159+
ClimaComms.@sync dev begin
160+
@async begin
161+
T_lim!(T_lim, U, p, t_exp)
162+
nothing
163+
end
164+
@async begin
165+
T_exp!(T_exp, U, p, t_exp)
166+
nothing
167+
end
168+
end
169+
else
170+
isnothing(T_lim!) || T_lim!(T_lim, U, p, t_exp)
171+
isnothing(T_exp!) || T_exp!(T_exp, U, p, t_exp)
161172
end
162173
end
163174
end

0 commit comments

Comments
 (0)