
Commit 83d0fd8

HandleCache fixes & GPUArrays caching allocator interface implementation (#710)
- Fix rocFFT HandleCache leak: sometimes a handle would not be freed because the wrong key was used.
- Rework HandleCache.
- Remove non-unicode conv aliases (NNlib uses unicode).
- Do not use `dlpath` during ROCm discovery.
1 parent 2e5261a commit 83d0fd8

File tree

17 files changed (+188, -415 lines)


Project.toml

Lines changed: 7 additions & 0 deletions
@@ -33,12 +33,19 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 UnsafeAtomics = "013be700-e6cd-48c3-b4a1-df204f14c38f"
 UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
 
+[weakdeps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+
+[extensions]
+AMDGPUChainRulesCoreExt = "ChainRulesCore"
+
 [compat]
 AbstractFFTs = "1.0"
 AcceleratedKernels = "0.2"
 Adapt = "4"
 Atomix = "0.1, 1"
 CEnum = "0.4, 0.5"
+ChainRulesCore = "1"
 ExprTools = "0.1"
 GPUArrays = "11.1"
 GPUCompiler = "0.27, 1.0"
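
Note: the new `[weakdeps]`/`[extensions]` entries make the ChainRulesCore glue a package extension, so it is compiled and loaded only when ChainRulesCore is also present in the environment. A minimal REPL sketch (assuming Julia 1.9+ extension support; the results shown as comments are illustrative):

    using AMDGPU

    # Without ChainRulesCore, the extension module does not exist yet.
    Base.get_extension(AMDGPU, :AMDGPUChainRulesCoreExt) === nothing  # true

    using ChainRulesCore

    # Loading the weak dependency triggers the extension.
    Base.get_extension(AMDGPU, :AMDGPUChainRulesCoreExt) isa Module   # true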

docs/make.jl

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ function main()
 "Exceptions" => "exceptions.md",
 "Profiling" => "profiling.md",
 "Memory" => "memory.md",
-"Caching Memory Allocator" => "caching_allocator.md",
 "Host-Call" => "hostcall.md",
 "Printing" => "printing.md",
 "Logging" => "logging.md",

docs/src/caching_allocator.md

Lines changed: 0 additions & 76 deletions
This file was deleted.

ext/AMDGPUChainRulesCoreExt.jl

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+module AMDGPUChainRulesCoreExt
+
+using AMDGPU: ROCArray
+
+import ChainRulesCore
+
+ChainRulesCore.is_inplaceable_destination(::ROCArray) = true
+
+end
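
Note: `ChainRulesCore.is_inplaceable_destination` is the hook ChainRules uses to decide whether an `InplaceableThunk` may accumulate a gradient into an existing array instead of allocating a fresh one; declaring it `true` opts `ROCArray` into that in-place path. A hedged sketch of the effect (assumes a working ROCm device):

    using AMDGPU, ChainRulesCore

    x = AMDGPU.rand(Float32, 16)

    # With the extension loaded, ChainRules-based AD may accumulate into
    # buffers like `x` in place.
    ChainRulesCore.is_inplaceable_destination(x)  # true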

src/AMDGPU.jl

Lines changed: 8 additions & 15 deletions
@@ -36,15 +36,8 @@ struct LockedObject{T}
     lock::ReentrantLock
     payload::T
 end
-
 LockedObject(payload) = LockedObject(ReentrantLock(), payload)
 
-function Base.lock(f, x::LockedObject)
-    Base.@lock x.lock begin
-        return f(x.payload)
-    end
-end
-
 # TODO simplify
 struct KernelState
     # Exception reporting buffers.
@@ -114,7 +107,6 @@ include("tls.jl")
 include("highlevel.jl")
 include("reflection.jl")
 include("array.jl")
-include("caching_allocator.jl")
 include("conversions.jl")
 include("broadcast.jl")
 include("exception_handler.jl")
@@ -139,21 +131,21 @@ include("random.jl")
 # Enable hardware FP atomics for +/- ops.
 const ROCIndexableRef{Indexable <: ROCDeviceArray} = Atomix.IndexableRef{Indexable}
 
-function Atomix.modify!(
-    ref::ROCIndexableRef, op::OP, x, ord,
-) where OP <: Union{typeof(+), typeof(-)}
+function Atomix.modify!(ref::ROCIndexableRef, op::OP, x, ord) where {
+    OP <: Union{typeof(+), typeof(-)}
+}
     x = Atomix.asstorable(ref, x)
     ptr = Atomix.pointer(ref)
     root = Atomix.gcroot(ref)
-    GC.@preserve root begin
-        UnsafeAtomics.modify!(ptr, op, x, ord, Val(:agent))
-    end
+    GC.@preserve root UnsafeAtomics.modify!(ptr, op, x, ord, Val(:agent))
 end
 
 include("ROCKernels.jl")
 import .ROCKernels: ROCBackend
 export ROCBackend
 
+# include("cache_allocator.jl")
+
 function __init__()
     # Used to shutdown hostcalls if any is running.
     atexit(() -> begin Runtime.RT_EXITING[] = true end)
@@ -174,7 +166,8 @@ function __init__()
     end
 
     if !isempty(libhsaruntime)
-        HSA.init() == HSA.STATUS_SUCCESS ?
+        status = HSA.init()
+        status == HSA.STATUS_SUCCESS ?
             atexit(() -> HSA.shut_down()) :
             @warn "HSA initialization failed with code $status"
     else
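
Note: the reworked `Atomix.modify!` method above routes floating-point `+`/`-` atomics on device arrays through `UnsafeAtomics.modify!` with `Val(:agent)` (device-scope) ordering, i.e. hardware FP atomics. A hedged sketch of code that exercises this path; the kernel name and sizes are illustrative, not part of this commit:

    using AMDGPU, KernelAbstractions, Atomix

    @kernel function atomic_sum!(acc, xs)
        i = @index(Global)
        # Float32 `+=` on a device array dispatches to the modify! method above.
        @inbounds Atomix.@atomic acc[1] += xs[i]
    end

    xs = AMDGPU.rand(Float32, 1024)
    acc = AMDGPU.zeros(Float32, 1)
    atomic_sum!(ROCBackend())(acc, xs; ndrange = length(xs))
    KernelAbstractions.synchronize(ROCBackend())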

src/array.jl

Lines changed: 20 additions & 20 deletions
@@ -7,24 +7,22 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
         ::UndefInitializer, dims::Dims{N},
     ) where {T, N, B <: Mem.AbstractAMDBuffer}
         @assert isbitstype(T) "ROCArray only supports bits types"
-
-        alloc_name = cache_alloc_name()
-        # Do not use caching allocator if it is not set or
-        # the buffer is not a device memory.
-        x = if !(B <: Mem.HIPBuffer) || alloc_name == :none
-            data = DataRef(pool_free, pool_alloc(B, prod(dims) * sizeof(T)))
-            x = new{T, N, B}(data, dims, 0)
-        else
-            alloc = cache_allocator!(alloc_name)
-            tmp = alloc!(alloc, B, T, dims)
-            if tmp ≡ nothing
-                data = DataRef(pool_free, pool_alloc(B, prod(dims) * sizeof(T)))
-                tmp = new{T, N, B}(data, dims, 0)
-                add_busy!(alloc, tmp)
-            end
-            tmp::ROCArray{T, N, B}
+        function _alloc_f()
+            sz::Int64 = prod(dims) * sizeof(T)
+            @debug "Allocate `T=$T`, `dims=$dims`: $(Base.format_bytes(sz))"
+            data = DataRef(pool_free, pool_alloc(B, sz))
+            finalizer(unsafe_free!, new{T, N, B}(data, dims, 0))
         end
-        return finalizer(unsafe_free!, x)
+        return _alloc_f()
+
+        # name = GPUArrays.CacheAllocatorName[]
+        # # Do not use caching allocator if it is not set or
+        # # the buffer is not a device memory.
+        # return if !(B <: Mem.HIPBuffer) || name == :none
+        #     _alloc_f()
+        # else
+        #     GPUArrays.alloc!(_alloc_f, ROCBackend(), name, T, dims)::ROCArray{T, N, B}
+        # end
     end
 
     function ROCArray{T, N}(
@@ -38,9 +36,7 @@ end
 
 GPUArrays.storage(a::ROCArray) = a.buf
 
-function GPUArrays.derive(
-    ::Type{T}, x::ROCArray, dims::Dims{N}, offset::Int,
-) where {N, T}
+function GPUArrays.derive(::Type{T}, x::ROCArray, dims::Dims{N}, offset::Int) where {N, T}
     ref = copy(x.buf)
     offset += (x.offset * Base.elsize(x)) ÷ sizeof(T)
     ROCArray{T, N}(ref, dims; offset)
@@ -154,6 +150,8 @@ function Base.copyto!(
     amount == 0 && return dest
     @boundscheck checkbounds(dest, d_offset + amount - 1)
     @boundscheck checkbounds(source, s_offset + amount - 1)
+
+    @debug "[gpu -> cpu] T=$T, shape=$(size(dest))"
     stm = stream()
     Mem.download!(
         pointer(dest, d_offset),
@@ -171,6 +169,8 @@
     amount == 0 && return dest
     @boundscheck checkbounds(dest, d_offset + amount - 1)
     @boundscheck checkbounds(source, s_offset + amount - 1)
+
+    @debug "[cpu -> gpu] T=$T, shape=$(size(dest))"
     Mem.upload!(
         Mem.view(convert(Mem.AbstractAMDBuffer, dest.buf[]),
             (dest.offset + d_offset - 1) * sizeof(T)),
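
Note: the new `@debug` statements are silent by default; they appear only when debug logging is enabled for the AMDGPU module via the standard `JULIA_DEBUG` mechanism. A small sketch (the printed lines are illustrative):

    ENV["JULIA_DEBUG"] = "AMDGPU"   # or start Julia with JULIA_DEBUG=AMDGPU

    using AMDGPU

    x = AMDGPU.zeros(Float32, 256, 256)
    # ┌ Debug: Allocate `T=Float32`, `dims=(256, 256)`: 256.000 KiB

    Array(x)
    # ┌ Debug: [gpu -> cpu] T=Float32, shape=(256, 256)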

src/cache.jl

Lines changed: 52 additions & 40 deletions
@@ -2,87 +2,99 @@
 # Copied from CUDA.jl/lib/utils/cache.jl
 
 # TODO:
-# - keep track of the (estimated?) size of cache contents
-# - clean the caches when memory is needed. this will require registering the destructor
-#   upfront, so that it can set the environment (e.g. switch to the appropriate context).
-#   alternatively, register the `unsafe_free!`` methods with the pool instead of the cache.
+# - store ctor/dtor in cache
+# - clean cache when under memory pressure
 
 export HandleCache
 
-struct HandleCache{K,V}
-    active_handles::Set{Pair{K,V}} # for debugging, and to prevent handle finalization
-    idle_handles::Dict{K,Vector{V}}
-    lock::ReentrantLock
+struct HandleCache{K, V}
+    active_handles::Set{Pair{K, V}}
+    idle_handles::Dict{K, Vector{V}}
+    lock::Base.ThreadSynchronizer
+    # TODO when finalizers are run on their own tasks use reentrant lock
 
     max_entries::Int
 
-    function HandleCache{K,V}(max_entries::Int=32) where {K,V}
-        return new{K,V}(Set{Pair{K,V}}(), Dict{K,Vector{V}}(), ReentrantLock(), max_entries)
+    function HandleCache{K, V}(max_entries::Int = 32) where {K, V}
+        new{K,V}(
+            Set{Pair{K, V}}(),
+            Dict{K, Vector{V}}(),
+            Base.ThreadSynchronizer(),
+            max_entries)
     end
 end
 
 # remove a handle from the cache, or create a new one
-function Base.pop!(f::Function, cache::HandleCache{K,V}, key) where {K,V}
-    function check_cache(f::Function=()->nothing)
-        lock(cache.lock) do
-            handle = if !haskey(cache.idle_handles, key) || isempty(cache.idle_handles[key])
-                f()
-            else
-                pop!(cache.idle_handles[key])
-            end
-
-            if handle !== nothing
-                push!(cache.active_handles, key=>handle)
-            end
-
-            return handle
+function Base.pop!(f::Function, cache::HandleCache{K, V}, key) where {K, V}
+    # Check cache.
+    handle, n_active_handles = Base.@lock cache.lock begin
+        if haskey(cache.idle_handles, key) && !isempty(cache.idle_handles[key])
+            pop!(cache.idle_handles[key]), length(cache.active_handles)
+        else
+            nothing, length(cache.active_handles)
         end
     end
 
-    handle = check_cache()
-
-    if handle === nothing
-        # if we didn't find anything, perform a quick GC collection to free up old handles.
+    # If didn't find anything, but lots of active handles - try to free some.
+    if handle ≡ nothing && n_active_handles > cache.max_entries
         GC.gc(false)
-
-        handle = check_cache(f)
+        Base.@lock cache.lock begin
+            if haskey(cache.idle_handles, key) && !isempty(cache.idle_handles[key])
+                handle = pop!(cache.idle_handles[key])
+            end
+        end
    end
 
+    # If still nothing, create a new handle.
+    handle ≡ nothing && (handle = f();)
+
+    Base.@lock cache.lock push!(cache.active_handles, key => handle)
    return handle::V
 end
 
 # put a handle in the cache, or destroy it if it doesn't fit
-function Base.push!(f::Function, cache::HandleCache{K,V}, key::K, handle::V) where {K,V}
-    lock(cache.lock) do
-        delete!(cache.active_handles, key=>handle)
+function Base.push!(f::Function, cache::HandleCache{K, V}, key::K, handle::V) where {K, V}
+    saved = Base.@lock cache.lock begin
+        (key => handle) ∉ cache.active_handles && error(
+            """Trying to free active handle that is not managed by cache.
+            - Key: $key
+            - Handle: $handle
+            """)
+        delete!(cache.active_handles, key => handle)
 
         if haskey(cache.idle_handles, key)
             if length(cache.idle_handles[key]) > cache.max_entries
-                f()
+                false
             else
                 push!(cache.idle_handles[key], handle)
+                true
             end
         else
             cache.idle_handles[key] = [handle]
+            true
        end
    end
+
+    saved || f()
+    return
 end
 
 # shorthand version to put a handle back without having to remember the key
-function Base.push!(f::Function, cache::HandleCache{K,V}, handle::V) where {K,V}
-    lock(cache.lock) do
+function Base.push!(f::Function, cache::HandleCache{K, V}, handle::V) where {K, V}
+    key = Base.@lock cache.lock begin
         key = nothing
         for entry in cache.active_handles
             if entry[2] == handle
                 key = entry[1]
                 break
            end
        end
-        if key === nothing
-            error("Attempt to cache handle $handle that was not created by the handle cache")
-        end
-        push!(f, cache, key, handle)
+
+        key ≡ nothing && error(
+            "Attempt to cache handle $handle that was not created by the handle cache")
+        key
    end
+    push!(f, cache, key, handle)
 end
 
 # Copied from CUDA.jl/lib/cublas/CUBLAS.jl
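
Note: the reworked `HandleCache` keeps the do-block API: `pop!` checks out an idle handle for a key (or creates one via the do-block), and `push!` must return it under the same key, either parking it as idle or destroying it via the do-block when the idle list is full. A hedged usage sketch; `create_handle`/`destroy_handle` and the key/handle types are hypothetical placeholders, not the real rocBLAS/rocFFT calls:

    const LIB_HANDLES = HandleCache{AMDGPU.HIP.HIPContext, Ptr{Cvoid}}()

    function with_library_handle(f, ctx::AMDGPU.HIP.HIPContext)
        # Reuse an idle handle for `ctx`, or create one via the do-block.
        handle = pop!(LIB_HANDLES, ctx) do
            create_handle(ctx)
        end
        try
            return f(handle)
        finally
            # Return the handle under the same key it was checked out with;
            # the do-block destroys it only if the idle list is already full.
            push!(LIB_HANDLES, ctx, handle) do
                destroy_handle(handle)
            end
        end
    end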

src/cache_allocator.jl

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+const ROCCacheAllocator = GPUArrays.PerDeviceCacheAllocator(ROCArray; free_immediately=false)
+
+GPUArrays.cache_allocator(::ROCBackend) = ROCCacheAllocator
+
+GPUArrays.device(::ROCBackend) = AMDGPU.device()
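
Note: this wires the experimental GPUArrays caching-allocator interface to the ROCm backend; the corresponding `include("cache_allocator.jl")` in src/AMDGPU.jl is still commented out. A hedged sketch using only names visible in this commit:

    # The per-device cache and the active device, as exposed to GPUArrays.
    alloc = GPUArrays.cache_allocator(ROCBackend())
    dev = GPUArrays.device(ROCBackend())

    # In the (currently disabled) ROCArray constructor path, allocations would
    # be routed as:
    #   GPUArrays.alloc!(_alloc_f, ROCBackend(), name, T, dims)
    # with `name` taken from `GPUArrays.CacheAllocatorName[]`.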
