@@ -2,36 +2,61 @@ import CUDAnative, CUDAdrv
2
2
import CUDAnative: cufunction
3
3
import CUDAdrv: CuEvent, CuStream, CuDefaultStream
4
4
5
- STREAMS = CuStream[]
6
- let id = 1
7
- global next_stream
8
- function next_stream ()
9
- global id
10
- stream = STREAMS[id]
11
- if id < length (STREAMS)
12
- id += 1
13
- else
14
- id = 1
5
# Freelist of streams that are ready for reuse, and the list of every
# stream this backend has created (scanned by the stream GC in
# `next_stream`).
const FREE_STREAMS = CuStream[]
const STREAMS = CuStream[]

# Once `length(STREAMS)` exceeds this threshold, `next_stream` scans all
# known streams for completed ones to recycle. Overridable at load time
# via the KERNELABSTRACTIONS_STREAMS_GC_THRESHOLD environment variable.
const STREAM_GC_THRESHOLD = Ref{Int}(16)
8
+
9
# Module-load hook: allow the stream-GC threshold to be overridden via the
# environment. Runs once when the backend is initialized.
@init begin
    if haskey(ENV, "KERNELABSTRACTIONS_STREAMS_GC_THRESHOLD")
        # No `global` declaration needed: STREAM_GC_THRESHOLD is a `const`
        # global Ref, and setting its contents is not a rebinding.
        STREAM_GC_THRESHOLD[] = parse(Int, ENV["KERNELABSTRACTIONS_STREAMS_GC_THRESHOLD"])
    end
end
15
+
16
## Stream GC
# Simplistic stream GC: when the total number of created streams grows past
# a threshold, scan them all and return those whose queued work has
# completed to the freelist.
# Alternative designs:
# - Enqueue a host function on the stream that adds the stream back to the freelist
# - Attach a finalizer to events that adds the stream back to the freelist
# Possible improvements
# - Add a background task that occasionally scans all streams
# - Add a hysteresis by checking a "since last scanned" timestamp
# - Add locking
function next_stream()
    # Fast path: recycle a stream from the freelist.
    isempty(FREE_STREAMS) || return pop!(FREE_STREAMS)

    # Freelist is empty; if we already have many streams, scan them and
    # move the idle ones (all work completed) onto the freelist.
    if length(STREAMS) > STREAM_GC_THRESHOLD[]
        for candidate in STREAMS
            if CUDAdrv.query(candidate)
                push!(FREE_STREAMS, candidate)
            end
        end
    end

    # The scan may have recovered something.
    isempty(FREE_STREAMS) || return pop!(FREE_STREAMS)

    # Nothing reusable: allocate a fresh non-blocking stream and track it.
    fresh = CUDAdrv.CuStream(CUDAdrv.STREAM_NON_BLOCKING)
    push!(STREAMS, fresh)
    return fresh
end
18
48
19
49
# Event wrapper for the CUDA backend; carries the raw CuEvent used to
# express dependencies between launches (see `wait(::CudaEvent, ...)`).
struct CudaEvent <: Event
    event::CuEvent
end
22
# Block until `ev` has completed.
#
# When a `progress` callback is supplied, poll the event and invoke
# `progress()` between polls so the caller can drive other work
# (e.g. MPI/libuv progress); otherwise do a blocking wait on the event.
function wait(ev::CudaEvent, progress=nothing)
    if progress !== nothing
        while !CUDAdrv.query(ev.event)
            progress()
            # do we need to `yield` here?
        end
    else
        CUDAdrv.wait(ev.event)
    end
end
37
62
@@ -43,13 +68,7 @@ function (obj::Kernel{CUDA})(args...; ndrange=nothing, dependencies=nothing, wor
43
68
dependencies = (dependencies,)
44
69
end
45
70
46
- # Be conservative and launch on CuDefaultStream
47
- if dependencies === nothing
48
- stream = CuDefaultStream ()
49
- else
50
- stream = next_stream ()
51
- end
52
-
71
+ stream = next_stream ()
53
72
if dependencies != = nothing
54
73
for event in dependencies
55
74
@assert event isa CudaEvent
182
201
###
# GPU implementation of shared memory
###
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
    # Each `Id` names a distinct statically-sized shared-memory allocation;
    # wrap the raw pointer as a device array in the shared address space.
    shmem = CUDAnative._shmem(Val(Id), T, Val(prod(Dims)))
    return CUDAnative.CuDeviceArray(Dims, CUDAnative.DevicePtr{T, CUDAnative.AS.Shared}(shmem))
end
192
211
# - private memory for each workitem
193
212
# ##
194
213
195
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
    # Private per-workitem storage, backed by an uninitialized MArray whose
    # size is derived from `Dims` via the `__size` helper.
    return MArray{__size(Dims), T}(undef)
end
217
+
218
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__synchronize))
    # Workgroup barrier: lower to CUDAnative's thread-block synchronization.
    return CUDAnative.sync_threads()
end
0 commit comments