Skip to content

Commit c869cc3

Browse files
pxl-thmaleadt
andauthored
Add an allocator cache (#576)
Co-authored-by: Tim Besard <[email protected]>
1 parent 5045f13 commit c869cc3

File tree

8 files changed

+232
-8
lines changed

8 files changed

+232
-8
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
14+
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
1415
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
1516
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617

@@ -23,6 +24,7 @@ LinearAlgebra = "1"
2324
Printf = "1"
2425
Random = "1"
2526
Reexport = "1"
27+
ScopedValues = "1"
2628
Serialization = "1"
2729
Statistics = "1"
2830
julia = "1.10"

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ function main()
2020
"Test suite" => "testsuite.md",
2121
],
2222
doctest = true,
23+
warnonly = [:missing_docs],
2324
)
2425

2526
deploydocs(

docs/src/interface.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Device functionality is then handled by [KernelAbstractions.jl](https://github.c
1010

1111
You should provide an array type that builds on the `AbstractGPUArray` supertype, such as:
1212

13-
```
13+
```julia
1414
mutable struct CustomArray{T, N} <: AbstractGPUArray{T, N}
1515
data::DataRef{Vector{UInt8}}
1616
offset::Int
@@ -23,10 +23,17 @@ end
2323
This will allow your defined type (in this case `CustomArray`) to use the GPUArrays interface where available.
2424
To be able to actually use the functionality that is defined for `AbstractGPUArray`s, you need to define the backend, like so:
2525

26-
```
26+
```julia
2727
import KernelAbstractions: Backend
2828
struct CustomBackend <: KernelAbstractions.GPU end
2929
KernelAbstractions.get_backend(a::CA) where CA <: CustomArray = CustomBackend()
3030
```
3131

3232
There are numerous examples of potential interfaces for GPUArrays, such as with [JLArrays](https://github.com/JuliaGPU/GPUArrays.jl/blob/master/lib/JLArrays/src/JLArrays.jl), [CuArrays](https://github.com/JuliaGPU/CUDA.jl/blob/master/src/gpuarrays.jl), and [ROCArrays](https://github.com/JuliaGPU/AMDGPU.jl/blob/master/src/gpuarrays.jl).
33+
34+
## Caching Allocator
35+
36+
```@docs
37+
GPUArrays.@cached
38+
GPUArrays.@uncached
39+
```

lib/JLArrays/src/JLArrays.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,16 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
8888
function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
8989
check_eltype(T)
9090
maxsize = prod(dims) * sizeof(T)
91-
data = Vector{UInt8}(undef, maxsize)
92-
ref = DataRef(data) do data
93-
resize!(data, 0)
94-
end
95-
obj = new{T,N}(ref, 0, dims)
96-
finalizer(unsafe_free!, obj)
91+
92+
return GPUArrays.cached_alloc((JLArray, T, dims)) do
93+
data = Vector{UInt8}(undef, maxsize)
94+
ref = DataRef(data) do data
95+
resize!(data, 0)
96+
end
97+
obj = new{T, N}(ref, 0, dims)
98+
finalizer(unsafe_free!, obj)
99+
return obj
100+
end::JLArray{T, N}
97101
end
98102

99103
# low-level constructor for wrapping existing data
@@ -102,6 +106,7 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
102106
check_eltype(T)
103107
obj = new{T,N}(ref, offset, dims)
104108
finalizer(unsafe_free!, obj)
109+
return obj
105110
end
106111
end
107112

src/GPUArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ include("host/random.jl")
3434
include("host/quirks.jl")
3535
include("host/uniformscaling.jl")
3636
include("host/statistics.jl")
37+
include("host/alloc_cache.jl")
3738

3839

3940
end # module

src/host/alloc_cache.jl

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
using ..GPUArrays
2+
3+
@static if VERSION < v"1.11"
4+
using ScopedValues
5+
else
6+
using Base.ScopedValues
7+
end
8+
9+
# Cache of arrays that can be handed out again instead of allocating anew.
#
# Both dictionaries are indexed by `hash(key)` of the allocation key and hold the arrays
# registered under that hash:
#   * `busy` - arrays handed out by `cached_alloc` and not yet released by `free_busy!`
#   * `free` - arrays available for reuse
# `lock` guards mutation of both dictionaries and of the vectors stored in them.
mutable struct AllocCache
    lock::ReentrantLock
    busy::Dict{UInt64, Vector{Any}}
    free::Dict{UInt64, Vector{Any}}

    function AllocCache()
        busy = Dict{UInt64, Vector{Any}}()
        free = Dict{UInt64, Vector{Any}}()
        self = new(ReentrantLock(), busy, free)
        # Release whatever is still parked in the cache once the cache itself is collected.
        finalizer(unsafe_free!, self)
        return self
    end
end
23+
24+
# Return the vector stored under `uid` in the pool dictionary named by `pool` (a field name
# of `cache`: `:busy` or `:free`), inserting an empty `Any[]` first when `uid` has no entry.
#
# Lookup and insertion happen as one step under `cache.lock`. Doing the lookup outside the
# lock would let two tasks both miss and then overwrite each other's freshly inserted
# vector, silently dropping whatever had already been pushed into the first one.
function get_pool!(cache::AllocCache, pool::Symbol, uid::UInt64)
    pools = getproperty(cache, pool)
    return Base.@lock cache.lock get!(() -> Any[], pools, uid)
end
32+
33+
# Return an array for `key`, reusing one from the active allocation cache when possible.
#
# `f` is a zero-argument callable that performs a fresh allocation; `key` identifies
# interchangeable allocations and is reduced to `hash(key)` for the pool lookup.
#
# * Without an active cache (`ALLOC_CACHE[]` is `nothing`) this is just `f()`.
# * Otherwise an entry is popped from the free pool for `key`, skipping entries whose
#   storage was already released; if none is usable, `f()` is called.
# The returned array is recorded in the busy pool for `key`.
function cached_alloc(f, key)
    cache = ALLOC_CACHE[]
    cache === nothing && return f()

    uid = hash(key)
    busy_pool = get_pool!(cache, :busy, uid)
    free_pool = get_pool!(cache, :free, uid)

    # The emptiness check and the `pop!` must sit under the same lock acquisition;
    # otherwise another task can drain the pool in between and `pop!` throws.
    x = Base.@lock cache.lock begin
        found = nothing
        while found === nothing && !isempty(free_pool)
            candidate = pop!(free_pool)
            # Array was manually freed via `unsafe_free!`.
            GPUArrays.storage(candidate).freed && continue
            found = candidate
        end
        found
    end

    # Allocate outside the lock so other tasks are not blocked while `f` runs.
    x === nothing && (x = f())
    Base.@lock cache.lock push!(busy_pool, x)
    return x
end
57+
58+
# Move every array currently registered as busy into the free pool with the same uid,
# leaving all busy pools empty. Returns `nothing`.
#
# Iterates the dictionary itself rather than its internal `keys` field: that field is the
# raw slot storage and also contains unoccupied slots, whose contents are not valid keys.
# The whole transfer runs under `cache.lock`, so the busy/free bookkeeping cannot be
# observed half-updated.
function free_busy!(cache::AllocCache)
    Base.@lock cache.lock begin
        for (uid, busy_pool) in cache.busy
            isempty(busy_pool) && continue
            # Only `cache.free` may gain an entry here; `cache.busy` is not modified
            # structurally while it is being iterated.
            free_pool = get_pool!(cache, :free, uid)
            append!(free_pool, busy_pool)
            empty!(busy_pool)
        end
    end
    return
end
71+
72+
# Release every array parked in the free pools of `cache` and forget those pools.
# Throws if any busy pool still holds an array. Returns `nothing`.
function unsafe_free!(cache::AllocCache)
    Base.@lock cache.lock begin
        if any(!isempty, values(cache.busy))
            error(
                "Invalidating allocations cache that's currently in use. " *
                "Invalidating inside `@cached` is not allowed."
            )
        end
        for pool in values(cache.free)
            foreach(unsafe_free!, pool)
        end
        empty!(cache.free)
    end
    return
end
87+
88+
# Sum of `sizeof` over every array held by `cache`, free and busy alike, as a `UInt64`.
function Base.sizeof(cache::AllocCache)
    return Base.@lock cache.lock begin
        total = UInt64(0)
        for pools in (cache.free, cache.busy)
            for pool in values(pools)
                total += sum(sizeof, pool; init = UInt64(0))
            end
        end
        total
    end
end
97+
98+
# Print a one-line summary: number of free and busy arrays and the total byte size.
function Base.show(io::IO, cache::AllocCache)
    # Take all three figures under one lock acquisition so they describe the same state.
    sz, n_free, n_busy = Base.@lock cache.lock begin
        (
            sizeof(cache),
            sum(length, values(cache.free); init = 0),
            sum(length, values(cache.busy); init = 0),
        )
    end
    return print(io, "AllocCache(n_free=$n_free, n_busy=$n_busy, sizeof=$(Base.format_bytes(sz)))")
end
107+
108+
const ALLOC_CACHE = ScopedValue{Union{Nothing, AllocCache}}(nothing)
109+
110+
"""
111+
@cached(cache, expr)
112+
113+
Evaluate `expr` using allocations cache `cache`.
114+
115+
When GPU memory is allocated during the execution of `expr`, `cache` will first be checked.
116+
If no memory is available in the cache, a new allocation will be requested.
117+
118+
After the execution of `expr`, all allocations made under the scope of `@cached` will be
119+
cached within `cache` for future use. This is useful to avoid relying on GC to free GPU
120+
memory in time.
121+
122+
Once `cache` goes out of scope, or when the user calls `unsafe_free!` on it, all cached
123+
allocations will be freed.
124+
125+
# Example
126+
127+
In the following example, each iteration of the for-loop requires 8 GiB of GPU memory.
128+
Without caching those allocations, significant pressure would be put on the GC, resulting
129+
in high memory usage and latency. By using the allocator cache, the memory usage is stable:
130+
131+
```julia
132+
cache = GPUArrays.AllocCache()
133+
for i in 1:1000
134+
GPUArrays.@cached cache begin
135+
sin.(CUDA.rand(Float32, 1024^3))
136+
end
137+
end
138+
139+
# optionally: free the memory now, instead of waiting for the GC to collect `cache`
140+
GPUArrays.unsafe_free!(cache)
141+
```
142+
143+
See [`@uncached`](@ref).
144+
"""
145+
macro cached(cache, expr)
    # Escape the caller's expressions up front; the cache expression is spliced in twice.
    scope_ex = esc(ALLOC_CACHE)
    cache_ex = esc(cache)
    body_ex = esc(expr)
    return quote
        # Run the body with the cache bound to the scoped value, then hand every array
        # it allocated back to the free pools before yielding the body's value.
        res = @with $scope_ex => $cache_ex $body_ex
        free_busy!($cache_ex)
        res
    end
end
152+
153+
"""
154+
@uncached(expr)
155+
156+
Evaluate expression `expr` without using the allocation cache. This is useful to call from within
157+
`@cached` to avoid caching some allocations, e.g., because they can be returned out of the
158+
`@cached` scope.
159+
"""
160+
macro uncached(expr)
    scope_ex = esc(ALLOC_CACHE)
    body_ex = esc(expr)
    return quote
        # Rebind the scoped value to `nothing` for the duration of the body.
        @with $scope_ex => nothing $body_ex
    end
end

test/testsuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ include("testsuite/math.jl")
9393
include("testsuite/random.jl")
9494
include("testsuite/uniformscaling.jl")
9595
include("testsuite/statistics.jl")
96+
include("testsuite/alloc_cache.jl")
9697

9798
"""
9899
Runs the entire GPUArrays test suite on array type `AT`

test/testsuite/alloc_cache.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
@testsuite "alloc cache" (AT, eltypes) -> begin
    if AT <: AbstractGPUArray
        cache = GPUArrays.AllocCache()

        T = Float32
        dims = (1, 2, 3)

        # First allocation: created inside the scope, parked in the free pool on exit.
        GPUArrays.@cached cache begin
            x1 = AT(zeros(T, dims))
        end
        @test sizeof(cache) == sizeof(T) * prod(dims)
        first_key = first(keys(cache.free))
        @test length(cache.free[first_key]) == 1
        @test length(cache.busy[first_key]) == 0
        @test x1 === cache.free[first_key][1]

        # Second allocation hits cache.
        GPUArrays.@cached cache begin
            x2 = AT(zeros(T, dims))
            # Does not hit the cache.
            GPUArrays.@uncached x_free = AT(zeros(T, dims))
        end
        @test sizeof(cache) == sizeof(T) * prod(dims)
        first_key = first(keys(cache.free))
        @test length(cache.free[first_key]) == 1
        @test length(cache.busy[first_key]) == 0
        @test x2 === cache.free[first_key][1]
        @test x_free !== x2

        # Third allocation is of different shape - allocates.
        dims = (2, 2)
        GPUArrays.@cached cache begin
            x3 = AT(zeros(T, dims))
        end
        free_keys = collect(keys(cache.free))
        second_key = free_keys[findfirst(!=(first_key), free_keys)]
        @test length(cache.free[first_key]) == 1
        @test length(cache.free[second_key]) == 1
        @test x3 === cache.free[second_key][1]

        # Freeing all memory held by cache.
        GPUArrays.unsafe_free!(cache)
        @test sizeof(cache) == 0
    end
end

0 commit comments

Comments
 (0)