Skip to content

Commit 50efa41

Browse files
authored
Merge pull request #10 from JuliaGPU/vc/streamgc
add stream GC and wait with progress function
2 parents eaaf82b + 5cc8dc3 commit 50efa41

File tree

5 files changed

+78
-40
lines changed

5 files changed

+78
-40
lines changed

Project.toml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Valentin Churavy <[email protected]>"]
44
version = "0.1.0"
55

66
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
78
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
89
CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
910
CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
@@ -12,13 +13,14 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1213
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1314

1415
[compat]
15-
CUDAapi = ">= 1.1"
16-
CUDAdrv = ">= 4.0"
17-
CUDAnative = ">= 2.2"
16+
Adapt = "0.4, 1.0"
17+
CUDAapi = "3.0"
18+
CUDAdrv = "6.0"
19+
CUDAnative = "2.10"
20+
Cassette = "0.3"
1821
Requires = "1.0"
19-
julia = ">= 1.3"
20-
Cassette = ">= 0.3"
2122
StaticArrays = "0.12"
23+
julia = "1.3"
2224

2325
[extras]
2426
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"

src/KernelAbstractions.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ end
105105
@synchronize()
106106
"""
107107
macro synchronize()
108-
@error "@synchronize not captured or used outside @kernel"
108+
quote
109+
$__synchronize()
110+
end
109111
end
110112

111113
"""
@@ -286,6 +288,10 @@ function SharedMemory(::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
286288
throw(MethodError(ScratchArray, (T, Val(Dims), Val(Id))))
287289
end
288290

291+
function __synchronize()
292+
error("@synchronize used outside kernel or not captured")
293+
end
294+
289295
###
290296
# Backends/Implementation
291297
###

src/backends/cpu.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,15 @@ struct CPUEvent <: Event
22
task::Core.Task
33
end
44

5-
function wait(ev::CPUEvent)
6-
wait(ev.task)
5+
function wait(ev::CPUEvent, progress=nothing)
6+
if progress === nothing
7+
wait(ev.task)
8+
else
9+
while !Base.istaskdone(ev.task)
10+
progress()
11+
yield() # yield to the scheduler
12+
end
13+
end
714
end
815

916
function (obj::Kernel{CPU})(args...; ndrange=nothing, workgroupsize=nothing, dependencies=nothing)

src/backends/cuda.jl

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,36 +2,61 @@ import CUDAnative, CUDAdrv
22
import CUDAnative: cufunction
33
import CUDAdrv: CuEvent, CuStream, CuDefaultStream
44

5-
STREAMS = CuStream[]
6-
let id = 1
7-
global next_stream
8-
function next_stream()
9-
global id
10-
stream = STREAMS[id]
11-
if id < length(STREAMS)
12-
id += 1
13-
else
14-
id = 1
5+
const FREE_STREAMS = CuStream[]
6+
const STREAMS = CuStream[]
7+
const STREAM_GC_THRESHOLD = Ref{Int}(16)
8+
9+
@init begin
10+
if haskey(ENV, "KERNELABSTRACTIONS_STREAMS_GC_THRESHOLD")
11+
global STREAM_GC_THRESHOLD[] = parse(Int, ENV["KERNELABSTRACTIONS_STREAMS_GC_THRESHOLD"])
12+
end
13+
14+
end
15+
16+
## Stream GC
17+
# Simplistic stream GC design: when the total number
18+
# of streams exceeds a threshold, we start scanning the streams
19+
# and add them back to the freelist if all work on them has completed.
20+
# Alternative designs:
21+
# - Enqueue a host function on the stream that adds the stream back to the freelist
22+
# - Attach a finalizer to events that adds the stream back to the freelist
23+
# Possible improvements
24+
# - Add a background task that occasionally scans all streams
25+
# - Add a hysteresis by checking a "since last scanned" timestamp
26+
# - Add locking
27+
function next_stream()
28+
if !isempty(FREE_STREAMS)
29+
return pop!(FREE_STREAMS)
30+
end
31+
32+
if length(STREAMS) > STREAM_GC_THRESHOLD[]
33+
for stream in STREAMS
34+
if CUDAdrv.query(stream)
35+
push!(FREE_STREAMS, stream)
36+
end
1537
end
1638
end
39+
40+
if !isempty(FREE_STREAMS)
41+
return pop!(FREE_STREAMS)
42+
end
43+
44+
stream = CUDAdrv.CuStream(CUDAdrv.STREAM_NON_BLOCKING)
45+
push!(STREAMS, stream)
46+
return stream
1747
end
1848

1949
struct CudaEvent <: Event
2050
event::CuEvent
2151
end
22-
function wait(ev::CudaEvent)
23-
# TODO: MPI/libuv progress
24-
CUDAdrv.wait(ev.event)
25-
end
26-
27-
@init begin
28-
if haskey(ENV, "KERNELABSTRACTIONS_STREAMS")
29-
nstreams = parse(Int, ENV["KERNELABSTRACTIONS_STREAMS"])
52+
function wait(ev::CudaEvent, progress=nothing)
53+
if progress === nothing
54+
CUDAdrv.wait(ev.event)
3055
else
31-
nstreams = 4
32-
end
33-
for i in 1:nstreams
34-
push!(STREAMS, CuStream(CUDAdrv.STREAM_NON_BLOCKING))
56+
while !CUDAdrv.query(ev.event)
57+
progress()
58+
# do we need to `yield` here?
59+
end
3560
end
3661
end
3762

@@ -43,13 +68,7 @@ function (obj::Kernel{CUDA})(args...; ndrange=nothing, dependencies=nothing, wor
4368
dependencies = (dependencies,)
4469
end
4570

46-
# Be conservative and launch on CuDefaultStream
47-
if dependencies === nothing
48-
stream = CuDefaultStream()
49-
else
50-
stream = next_stream()
51-
end
52-
71+
stream = next_stream()
5372
if dependencies !== nothing
5473
for event in dependencies
5574
@assert event isa CudaEvent
@@ -182,7 +201,7 @@ end
182201
###
183202
# GPU implementation of shared memory
184203
###
185-
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
204+
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
186205
ptr = CUDAnative._shmem(Val(Id), T, Val(prod(Dims)))
187206
CUDAnative.CuDeviceArray(Dims, CUDAnative.DevicePtr{T, CUDAnative.AS.Shared}(ptr))
188207
end
@@ -192,6 +211,10 @@ end
192211
# - private memory for each workitem
193212
###
194213

195-
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
214+
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
196215
MArray{__size(Dims), T}(undef)
197216
end
217+
218+
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__synchronize))
219+
CUDAnative.sync_threads()
220+
end

src/macros.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ function split(stmts)
8888

8989
for stmt in stmts.args
9090
if isexpr(stmt, :macrocall) && stmt.args[1] === Symbol("@synchronize")
91-
push!(loops, (current, copy(indicies), allocations))
91+
push!(loops, (current, deepcopy(indicies), allocations))
9292
allocations = Any[]
9393
current = Any[]
9494
continue

0 commit comments

Comments
 (0)