|
| 1 | +module CudaExt |
| 2 | + |
| 3 | +import CUDA |
| 4 | +import ClimaComms: SingletonCommsContext, CUDADevice |
| 5 | +import ClimaTimeSteppers: compute_T_lim_T_exp! |
| 6 | + |
# Compute the limited (`T_lim`) and explicit (`T_exp`) tendencies for a
# single-process CUDA run, choosing between two concurrency strategies.
# TODO: benchmark the two strategies to see if one is preferable.
@inline function compute_T_lim_T_exp!(
    T_lim,
    T_exp,
    U,
    p,
    t,
    T_lim!,
    T_exp!,
    ::SingletonCommsContext{CUDADevice},
)
    # With multiple Julia threads we can overlap work via spawned tasks;
    # otherwise fall back to explicit CUDA streams on the single thread.
    multithreaded = Base.Threads.nthreads() > 1
    impl! = multithreaded ? compute_T_lim_T_exp_spawn! : compute_T_lim_T_exp_streams!
    return impl!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)
end
| 16 | + |
# Overlap the `T_lim!` and `T_exp!` tendency computations by enqueueing each on
# its own CUDA stream, using events for cross-stream ordering:
#   1. Record `event` on the main stream so both side streams begin only after
#      all previously enqueued main-stream work.
#   2. Enqueue `T_lim!` on `stream1` and `T_exp!` on `stream2`; the host does
#      not block, so the two kernels may execute concurrently on the device.
#   3. Make the main stream wait on both completion events, so work enqueued
#      afterwards observes both tendencies.
# NOTE(review): the completion events are constructed inside the `stream!`
# blocks but recorded afterwards; recording explicitly targets `stream1`/
# `stream2`, so stream ordering is still correct.
@inline function compute_T_lim_T_exp_streams!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)
    event = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
    CUDA.record(event, CUDA.stream()) # record event on main stream

    stream1 = CUDA.CuStream() # make a stream
    local event1
    CUDA.stream!(stream1) do # work to be done by stream1
        CUDA.wait(event, stream1) # make stream1 wait on event (host continues)
        T_lim!(T_lim, U, p, t)
        event1 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
    end
    CUDA.record(event1, stream1) # record event1 on stream1

    stream2 = CUDA.CuStream() # make a stream
    local event2
    CUDA.stream!(stream2) do # work to be done by stream2
        CUDA.wait(event, stream2) # make stream2 wait on event (host continues)
        T_exp!(T_exp, U, p, t)
        event2 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
    end
    CUDA.record(event2, stream2) # record event2 on stream2

    CUDA.wait(event1, CUDA.stream()) # make main stream wait on event1
    CUDA.wait(event2, CUDA.stream()) # make main stream wait on event2
end
| 42 | + |
# Overlap the `T_lim!` and `T_exp!` tendency computations by evaluating each in
# its own Julia task; CUDA.jl gives every task its own stream, so the two
# computations may execute concurrently on the device.
@inline function compute_T_lim_T_exp_spawn!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)

    # Drain previously enqueued GPU work before forking, so both tendency
    # computations observe a consistent state of `U` and `p`.
    CUDA.synchronize()
    # This must be `Base.@sync` (wait for the spawned *tasks*), not
    # `CUDA.@sync`: the latter only synchronizes the current task's stream
    # after running the block and would let the host return before the
    # spawned tasks — and the GPU work they enqueue — have finished.
    Base.@sync begin
        Base.Threads.@spawn begin
            T_lim!(T_lim, U, p, t)
            CUDA.synchronize() # wait for this task's GPU work to complete
            nothing
        end
        Base.Threads.@spawn begin
            T_exp!(T_exp, U, p, t)
            CUDA.synchronize() # wait for this task's GPU work to complete
            nothing
        end
    end
end
| 59 | + |
| 60 | +end |
0 commit comments