Skip to content

Commit 8df83cc

Browse files
Add support for async f and j
1 parent f2e2b71 commit 8df83cc

File tree

3 files changed

+101
-0
lines changed

3 files changed

+101
-0
lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
1717
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1818
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1919

20+
[weakdeps]
21+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
22+
23+
[extensions]
24+
CudaExt = "CUDA"
25+
2026
[compat]
2127
ClimaComms = "0.4, 0.5"
2228
Colors = "0.12"

ext/CudaExt.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
module CudaExt
2+
3+
import CUDA
4+
import ClimaComms: SingletonCommsContext, CUDADevice
5+
import ClimaTimeSteppers: compute_T_lim_T_exp!
6+
7+
"""
    compute_T_lim_T_exp!(T_lim, T_exp, U, p, t, T_lim!, T_exp!, ::SingletonCommsContext{CUDADevice})

Method selected for a `SingletonCommsContext{CUDADevice}` context.

Forwards all arguments (except the context) to `compute_T_lim_T_exp_spawn!`
when Julia was started with more than one thread, and to
`compute_T_lim_T_exp_streams!` otherwise.
"""
@inline function compute_T_lim_T_exp!(
    T_lim,
    T_exp,
    U,
    p,
    t,
    T_lim!,
    T_exp!,
    ::SingletonCommsContext{CUDADevice},
)
    # TODO: we should benchmark these two options to
    # see if one is preferable over the other
    multithreaded = Base.Threads.nthreads() > 1
    multithreaded &&
        return compute_T_lim_T_exp_spawn!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)
    return compute_T_lim_T_exp_streams!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)
end
16+
17+
# Run `f!(out, U, p, t)` on a freshly created CUDA stream that first waits on
# `start_event`, then record and return an event marking the end of that work.
#
# The completion event is created outside the `do` closure on purpose: a
# variable that is assigned inside a closure is captured as a boxed value,
# which makes the later `CUDA.record` call type-unstable.
@inline function _run_on_new_stream!(f!, out, U, p, t, start_event)
    side_stream = CUDA.CuStream()
    done_event = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
    CUDA.stream!(side_stream) do # work to be done by side_stream
        # make side_stream wait on start_event (host continues)
        CUDA.wait(start_event, side_stream)
        f!(out, U, p, t)
    end
    CUDA.record(done_event, side_stream) # mark completion on side_stream
    return done_event
end

"""
    compute_T_lim_T_exp_streams!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)

Call `T_lim!(T_lim, U, p, t)` and `T_exp!(T_exp, U, p, t)` on two separate,
newly created CUDA streams.

An event is recorded on the current stream and both new streams wait on it
before running their call. The current stream is then made to wait on a
completion event recorded on each new stream, so work queued on it afterwards
is ordered after both calls.
"""
@inline function compute_T_lim_T_exp_streams!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)
    main_stream = CUDA.stream()
    start_event = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
    CUDA.record(start_event, main_stream) # record event on main stream

    lim_done = _run_on_new_stream!(T_lim!, T_lim, U, p, t, start_event)
    exp_done = _run_on_new_stream!(T_exp!, T_exp, U, p, t, start_event)

    CUDA.wait(lim_done, main_stream) # make main stream wait on T_lim! work
    return CUDA.wait(exp_done, main_stream) # make main stream wait on T_exp! work
end
42+
43+
"""
    compute_T_lim_T_exp_spawn!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)

Call `T_lim!(T_lim, U, p, t)` and `T_exp!(T_exp, U, p, t)` concurrently, each
in its own task created with `Threads.@spawn`, and wait for both tasks to
finish before returning.

Each task calls `CUDA.synchronize()` after its call, and the caller calls
`CUDA.synchronize()` once before spawning.
"""
@inline function compute_T_lim_T_exp_spawn!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)
    # Synchronize before handing the inputs to the tasks.
    # NOTE(review): presumably needed because each task runs on its own CUDA
    # stream and must see completed writes to `U` — confirm.
    CUDA.synchronize()
    # `Base.@sync` (not `CUDA.@sync`) is required here: `CUDA.@sync` only
    # synchronizes the GPU after evaluating its expression and does not wait
    # for tasks spawned inside it, so the function could return before
    # `T_lim!`/`T_exp!` had run.
    Base.@sync begin
        Base.Threads.@spawn begin
            T_lim!(T_lim, U, p, t)
            CUDA.synchronize()
            nothing
        end
        Base.Threads.@spawn begin
            T_exp!(T_exp, U, p, t)
            CUDA.synchronize()
            nothing
        end
    end
end
59+
60+
end

src/solvers/compute_T_exp_T_lim.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""
    compute_T_lim_T_exp!(T_lim, T_exp, U, p, t, T_lim!, T_exp!, comms_ctx)

Method for `comms_ctx::Union{Nothing, ClimaComms.AbstractCommsContext}`.

Calls `T_lim!(T_lim, U, p, t)` and then `T_exp!(T_exp, U, p, t)`, one after
the other on the calling task. The context argument is used for dispatch only.
"""
@inline function compute_T_lim_T_exp!(
    T_lim, T_exp, U, p, t, T_lim!, T_exp!,
    ::Union{Nothing, ClimaComms.AbstractCommsContext},
)
    T_lim!(T_lim, U, p, t)
    return T_exp!(T_exp, U, p, t)
end
14+
15+
"""
    compute_T_lim_T_exp!(T_lim, T_exp, U, p, t, T_lim!, T_exp!, comms_ctx)

Method for
`comms_ctx::ClimaComms.SingletonCommsContext{ClimaComms.CPUMultiThreaded}`.

Runs `T_lim!(T_lim, U, p, t)` and `T_exp!(T_exp, U, p, t)` in two tasks
created with `Threads.@spawn`, and waits for both before returning.
"""
@inline function compute_T_lim_T_exp!(
    T_lim, T_exp, U, p, t, T_lim!, T_exp!,
    ::ClimaComms.SingletonCommsContext{ClimaComms.CPUMultiThreaded},
)
    # Each task discards its callback's result by evaluating to `nothing`.
    Base.@sync begin
        Base.Threads.@spawn (T_lim!(T_lim, U, p, t); nothing)
        Base.Threads.@spawn (T_exp!(T_exp, U, p, t); nothing)
    end
end

0 commit comments

Comments
 (0)