Skip to content

Commit b29df31

Browse files
Add support for async f and j
1 parent f2e2b71 commit b29df31

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
1717
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1818
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1919

20+
[weakdeps]
21+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
22+
23+
[extensions]
24+
CudaExt = "CUDA"
25+
2026
[compat]
2127
ClimaComms = "0.4, 0.5"
2228
Colors = "0.12"

ext/CudaExt.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
module CudaExt
2+
3+
import CUDA
4+
import ClimaComms: SingletonCommsContext, CUDADevice
5+
import ClimaTimeSteppers: compute_fj!
6+
7+
# Evaluate `f!(f, U)` and `j!(j, U)` for a single-process CUDA context.
#
# Picks one of two concurrent strategies based on how many Julia threads are
# available: the task-based `compute_fj_spawn!` when there is more than one
# thread, otherwise the stream-based `compute_fj_streams!`.
@inline function compute_fj!(f, j, U, f!, j!, ::SingletonCommsContext{CUDADevice})
    # TODO: we should benchmark these two options to
    # see if one is preferable over the other
    multithreaded = Base.Threads.nthreads() > 1
    return multithreaded ? compute_fj_spawn!(f, j, U, f!, j!) :
           compute_fj_streams!(f, j, U, f!, j!)
end
16+
17+
# Evaluate `f!(f, U)` and `j!(j, U)` on two separate CUDA streams.
#
# An event recorded on the calling (main) stream gates both side streams, and
# the main stream is then made to wait on an event recorded on each side
# stream after its work was queued. The host is not blocked by these waits.
#
# The events are created up front rather than assigned inside the `do`
# closures: a variable that is assigned inside a closure and read outside of
# it gets boxed, which makes the later `record`/`wait` calls dynamically typed.
@inline function compute_fj_streams!(f, j, U, f!, j!)
    main_stream = CUDA.stream()

    start_event = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
    CUDA.record(start_event, main_stream) # record event on main stream

    stream1 = CUDA.CuStream() # side stream for f!
    event1 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
    CUDA.stream!(stream1) do # work to be done by stream1
        CUDA.wait(start_event, stream1) # stream1 waits on main-stream event
        f!(f, U)
    end
    CUDA.record(event1, stream1) # record event1 on stream1

    stream2 = CUDA.CuStream() # side stream for j!
    event2 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
    CUDA.stream!(stream2) do # work to be done by stream2
        CUDA.wait(start_event, stream2) # stream2 waits on main-stream event
        j!(j, U)
    end
    CUDA.record(event2, stream2) # record event2 on stream2

    CUDA.wait(event1, main_stream) # make main stream wait on event1
    CUDA.wait(event2, main_stream) # make main stream wait on event2
end
42+
43+
# Evaluate `f!(f, U)` and `j!(j, U)` concurrently, each in its own Julia task.
#
# The device is synchronized before the tasks are spawned, and each task
# synchronizes after its own call, so its work has completed by the time the
# task finishes.
#
# `Base.@sync` (not `CUDA.@sync`) is required here: `CUDA.@sync` only
# evaluates its expression and then synchronizes the GPU; it does not wait for
# tasks started with `Threads.@spawn`, so the function could return while
# `f!`/`j!` were still running. `Base.@sync` blocks until both lexically
# enclosed tasks have finished and surfaces any exception they throw.
@inline function compute_fj_spawn!(f, j, U, f!, j!)
    CUDA.synchronize()
    Base.@sync begin
        Base.Threads.@spawn begin
            f!(f, U)
            CUDA.synchronize()
            nothing
        end
        Base.Threads.@spawn begin
            j!(j, U)
            CUDA.synchronize()
            nothing
        end
    end
end
59+
60+
end

src/solvers/compute_fj.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Fallback used when no context is given (`nothing`) or for any comms context
# without a more specific method: call `f!(f, U)` and then `j!(j, U)`, one
# after the other, on the calling task.
@inline function compute_fj!(
    f,
    j,
    U,
    f!,
    j!,
    ::Union{Nothing, ClimaComms.AbstractCommsContext},
)
    f!(f, U)
    return j!(j, U)
end
5+
6+
# Multithreaded CPU method: run `f!(f, U)` and `j!(j, U)` in two spawned tasks
# and block until both have finished. Each task evaluates to `nothing`, so the
# return values of `f!` and `j!` are discarded.
@inline function compute_fj!(
    f,
    j,
    U,
    f!,
    j!,
    ::ClimaComms.SingletonCommsContext{ClimaComms.CPUMultiThreaded},
)
    return Base.@sync begin
        Base.Threads.@spawn (f!(f, U); nothing)
        Base.Threads.@spawn (j!(j, U); nothing)
    end
end

0 commit comments

Comments
 (0)