@@ -2,18 +2,48 @@ import CUDAnative, CUDAdrv
2
2
import CUDAnative: cufunction
3
3
import CUDAdrv: CuEvent, CuStream, CuDefaultStream
4
4
5
- STREAMS = CuStream[]
6
- let id = 1
7
- global next_stream
8
- function next_stream ()
9
- global id
10
- stream = STREAMS[id]
11
- if id < length (STREAMS)
12
- id += 1
13
- else
14
- id = 1
5
+ const FREE_STREAMS = CuStream[]
6
+ const STREAMS = CuStream[]
7
+ const STREAM_GC_THRESHOLD = Ref {Int} (16 )
8
+
9
+ @init begin
10
+ if haskey (ENV , " KERNELABSTRACTIONS_STREAMS_GC_THRESHOLD" )
11
+ global STREAM_GC_THRESHOLD[] = parse (Int, ENV [" KERNELABSTRACTIONS_STREAMS_GC_THRESHOLD" ])
12
+ end
13
+
14
+ end
15
+
16
+ # # Stream GC
17
+ # Simplistic stream GC design: once the total number of streams
18
+ # exceeds a threshold, we scan the streams and add each one back to
19
+ # the freelist if all work on it has completed.
20
+ # Alternative designs:
21
+ # - Enqueue a host function on the stream that adds the stream back to the freelist
22
+ # - Attach a finalizer to events that adds the stream back to the freelist
23
+ # Possible improvements
24
+ # - Add a background task that occasionally scans all streams
25
+ # - Add a hysteresis by checking a "since last scanned" timestamp
26
+ # - Add locking
27
+ function next_stream ()
28
+ if ! isempty (FREE_STREAMS)
29
+ return pop! (FREE_STREAMS)
30
+ end
31
+
32
+ if length (STREAMS) > STREAM_GC_THRESHOLD[]
33
+ for stream in STREAMS
34
+ if CUDAdrv. query (stream)
35
+ push! (FREE_STREAMS, stream)
36
+ end
15
37
end
16
38
end
39
+
40
+ if ! isempty (FREE_STREAMS)
41
+ return pop! (FREE_STREAMS)
42
+ end
43
+
44
+ stream = CUDAdrv. CuStream (CUDAdrv. STREAM_NON_BLOCKING)
45
+ push! (STREAMS, stream)
46
+ return stream
17
47
end
18
48
19
49
struct CudaEvent <: Event
@@ -30,17 +60,6 @@ function wait(ev::CudaEvent, progress=nothing)
30
60
end
31
61
end
32
62
33
- @init begin
34
- if haskey (ENV , " KERNELABSTRACTIONS_STREAMS" )
35
- nstreams = parse (Int, ENV [" KERNELABSTRACTIONS_STREAMS" ])
36
- else
37
- nstreams = 4
38
- end
39
- for i in 1 : nstreams
40
- push! (STREAMS, CuStream (CUDAdrv. STREAM_NON_BLOCKING))
41
- end
42
- end
43
-
44
63
function (obj:: Kernel{CUDA} )(args... ; ndrange= nothing , dependencies= nothing , workgroupsize= nothing )
45
64
if ndrange isa Int
46
65
ndrange = (ndrange,)
@@ -49,13 +68,7 @@ function (obj::Kernel{CUDA})(args...; ndrange=nothing, dependencies=nothing, wor
49
68
dependencies = (dependencies,)
50
69
end
51
70
52
- # Be conservative and launch on CuDefaultStream
53
- if dependencies === nothing
54
- stream = CuDefaultStream ()
55
- else
56
- stream = next_stream ()
57
- end
58
-
71
+ stream = next_stream ()
59
72
if dependencies != = nothing
60
73
for event in dependencies
61
74
@assert event isa CudaEvent
0 commit comments