Skip to content

Commit 9d5769b

Browse files
committed
implement simple stream GC
1 parent 9db56f1 commit 9d5769b

File tree

1 file changed

+41
-28
lines changed

1 file changed

+41
-28
lines changed

src/backends/cuda.jl

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,48 @@ import CUDAnative, CUDAdrv
22
import CUDAnative: cufunction
33
import CUDAdrv: CuEvent, CuStream, CuDefaultStream
44

5-
STREAMS = CuStream[]
6-
let id = 1
7-
global next_stream
8-
function next_stream()
9-
global id
10-
stream = STREAMS[id]
11-
if id < length(STREAMS)
12-
id += 1
13-
else
14-
id = 1
5+
const FREE_STREAMS = CuStream[]
6+
const STREAMS = CuStream[]
7+
const STREAM_GC_THRESHOLD = Ref{Int}(16)
8+
9+
@init begin
10+
if haskey(ENV, "KERNELABSTRACTIONS_STREAMS_GC_THRESHOLD")
11+
global STREAM_GC_THRESHOLD[] = parse(Int, ENV["KERNELABSTRACTIONS_STREAMS_GC_THRESHOLD"])
12+
end
13+
14+
end
15+
16+
## Stream GC
17+
# Simplistic stream gc design in which when we have a total number
18+
# of streams bigger than a threshold, we start scanning the streams
19+
# and add them back to the freelist if all work on them has completed.
20+
# Alternative designs:
21+
# - Enqueue a host function on the stream that adds the stream back to the freelist
22+
# - Attach a finalizer to events that adds the stream back to the freelist
23+
# Possible improvements
24+
# - Add a background task that occasionally scans all streams
25+
# - Add a hysterisis by checking a "since last scanned" timestamp
26+
# - Add locking
27+
function next_stream()
28+
if !isempty(FREE_STREAMS)
29+
return pop!(FREE_STREAMS)
30+
end
31+
32+
if length(STREAMS) > STREAM_GC_THRESHOLD[]
33+
for stream in STREAMS
34+
if CUDAdrv.query(stream)
35+
push!(FREE_STREAMS, stream)
36+
end
1537
end
1638
end
39+
40+
if !isempty(FREE_STREAMS)
41+
return pop!(FREE_STREAMS)
42+
end
43+
44+
stream = CUDAdrv.CuStream(CUDAdrv.STREAM_NON_BLOCKING)
45+
push!(STREAMS, stream)
46+
return stream
1747
end
1848

1949
struct CudaEvent <: Event
@@ -30,17 +60,6 @@ function wait(ev::CudaEvent, progress=nothing)
3060
end
3161
end
3262

33-
@init begin
34-
if haskey(ENV, "KERNELABSTRACTIONS_STREAMS")
35-
nstreams = parse(Int, ENV["KERNELABSTRACTIONS_STREAMS"])
36-
else
37-
nstreams = 4
38-
end
39-
for i in 1:nstreams
40-
push!(STREAMS, CuStream(CUDAdrv.STREAM_NON_BLOCKING))
41-
end
42-
end
43-
4463
function (obj::Kernel{CUDA})(args...; ndrange=nothing, dependencies=nothing, workgroupsize=nothing)
4564
if ndrange isa Int
4665
ndrange = (ndrange,)
@@ -49,13 +68,7 @@ function (obj::Kernel{CUDA})(args...; ndrange=nothing, dependencies=nothing, wor
4968
dependencies = (dependencies,)
5069
end
5170

52-
# Be conservative and launch on CuDefaultStream
53-
if dependencies === nothing
54-
stream = CuDefaultStream()
55-
else
56-
stream = next_stream()
57-
end
58-
71+
stream = next_stream()
5972
if dependencies !== nothing
6073
for event in dependencies
6174
@assert event isa CudaEvent

0 commit comments

Comments
 (0)