@@ -50,6 +50,7 @@ function step_u!(integrator, cache::IMEXARKCache)
5050 (; u, p, t, dt, alg) = integrator
5151 (; f) = integrator. sol. prob
5252 (; post_explicit!, post_implicit!) = f
53+ (; comms_context) = f
5354 (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
5455 (; tableau, newtons_method) = alg
5556 (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
@@ -74,19 +75,17 @@ function step_u!(integrator, cache::IMEXARKCache)
7475
7576 @. U = u
7677
77- if ! isnothing (T_lim!) # Update based on limited tendencies from previous stages
78- for j in 1 : (i - 1 )
79- iszero (a_exp[i, j]) && continue
80- @. U += dt * a_exp[i, j] * T_lim[j]
81- end
82- lim! (U, p, t_exp, u)
78+ # Update based on limited tendencies from previous stages
79+ for j in 1 : (i - 1 )
80+ iszero (a_exp[i, j]) && continue
81+ @. U += dt * a_exp[i, j] * T_lim[j]
8382 end
83+ lim! (U, p, t_exp, u)
8484
85- if ! isnothing (T_exp!) # Update based on explicit tendencies from previous stages
86- for j in 1 : (i - 1 )
87- iszero (a_exp[i, j]) && continue
88- @. U += dt * a_exp[i, j] * T_exp[j]
89- end
85+ # Update based on explicit tendencies from previous stages
86+ for j in 1 : (i - 1 )
87+ iszero (a_exp[i, j]) && continue
88+ @. U += dt * a_exp[i, j] * T_exp[j]
9089 end
9190
9291 if ! isnothing (T_imp!) # Update based on implicit tendencies from previous stages
@@ -147,32 +146,54 @@ function step_u!(integrator, cache::IMEXARKCache)
147146 end
148147
149148 if ! all (iszero, a_exp[:, i]) || ! iszero (b_exp[i])
150- if ! isnothing (T_lim! )
149+ if isnothing (comms_context )
151150 T_lim! (T_lim[i], U, p, t_exp)
152- end
153- if ! isnothing (T_exp!)
154151 T_exp! (T_exp[i], U, p, t_exp)
152+ else # do asynchronously
153+
154+ # https://github.com/JuliaLang/julia/issues/40626
155+ if ClimaComms. device (comms_context) isa CUDA. CUDADevice
156+ CUDA. @sync begin
157+ @async begin
158+ T_lim! (T_lim[i], U, p, t_exp)
159+ nothing
160+ end
161+ @async begin
162+ T_exp! (T_exp[i], U, p, t_exp)
163+ nothing
164+ end
165+ end
166+ else
167+ @sync begin
168+ @async begin
169+ T_lim! (T_lim[i], U, p, t_exp)
170+ nothing
171+ end
172+ @async begin
173+ T_exp! (T_exp[i], U, p, t_exp)
174+ nothing
175+ end
176+ end
177+ end
155178 end
156179 end
157180 end
158181
159182 t_final = t + dt
160183
161- if ! isnothing (T_lim!) # Update based on limited tendencies from previous stages
162- @. temp = u
163- for j in 1 : s
164- iszero (b_exp[j]) && continue
165- @. temp += dt * b_exp[j] * T_lim[j]
166- end
167- lim! (temp, p, t_final, u)
168- @. u = temp
184+ # Update based on limited tendencies from previous stages
185+ @. temp = u
186+ for j in 1 : s
187+ iszero (b_exp[j]) && continue
188+ @. temp += dt * b_exp[j] * T_lim[j]
169189 end
190+ lim! (temp, p, t_final, u)
191+ @. u = temp
170192
171- if ! isnothing (T_exp!) # Update based on explicit tendencies from previous stages
172- for j in 1 : s
173- iszero (b_exp[j]) && continue
174- @. u += dt * b_exp[j] * T_exp[j]
175- end
193+ # Update based on explicit tendencies from previous stages
194+ for j in 1 : s
195+ iszero (b_exp[j]) && continue
196+ @. u += dt * b_exp[j] * T_exp[j]
176197 end
177198
178199 if ! isnothing (T_imp!) # Update based on implicit tendencies from previous stages
0 commit comments