Skip to content

Commit 07571f8

Browse files
committed
Teach DataRef about cached references to support unsafe_free.
1 parent 41a7457 commit 07571f8

File tree

3 files changed

+52
-41
lines changed

3 files changed

+52
-41
lines changed

src/host/abstractarray.jl

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,11 +57,12 @@ end
5757
mutable struct DataRef{D}
5858
rc::RefCounted{D}
5959
freed::Bool
60+
cached::Bool
6061
end
6162

6263
function DataRef(finalizer, ref::D) where {D}
6364
rc = RefCounted{D}(ref, finalizer, Threads.Atomic{Int}(1))
64-
DataRef{D}(rc, false)
65+
DataRef{D}(rc, false, false)
6566
end
6667
DataRef(ref; kwargs...) = DataRef(nothing, ref; kwargs...)
6768

@@ -79,10 +80,16 @@ function Base.copy(ref::DataRef{D}) where {D}
7980
throw(ArgumentError("Attempt to copy a freed reference."))
8081
end
8182
retain(ref.rc)
82-
return DataRef{D}(ref.rc, false)
83+
# copies of cached references are not managed by the cache, so
84+
# we need to mark them as such to make sure their refcount can drop.
85+
return DataRef{D}(ref.rc, false, false)
8386
end
8487

8588
function unsafe_free!(ref::DataRef)
89+
if ref.cached
90+
# lifetimes of cached references are tied to the cache.
91+
return
92+
end
8693
if ref.freed
8794
# multiple frees *of the same object* are allowed.
8895
# we should only ever call `release` once per object, though,

src/host/alloc_cache.jl

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -44,15 +44,12 @@ function cached_alloc(f, key)
4444

4545
if !isempty(free_pool)
4646
ref = Base.@lock cache.lock pop!(free_pool)
47-
@assert !ref.freed
4847
end
4948
end
5049

5150
if ref === nothing
5251
ref = f()::DataRef
53-
54-
# increase the refcount of the ref to prevent finalizers from freeing it
55-
retain(ref.rc)
52+
ref.cached = true
5653
end
5754

5855
Base.@lock cache.lock begin
@@ -80,16 +77,11 @@ end
8077
function unsafe_free!(cache::AllocCache)
8178
Base.@lock cache.lock begin
8279
for pool in values(cache.busy)
83-
isempty(pool) || error(
84-
"Invalidating allocations cache that's currently in use. " *
85-
"Invalidating inside `@cached` is not allowed."
86-
)
80+
isempty(pool) || error("Cannot invalidate a cache that's in active use")
8781
end
8882
for pool in values(cache.free), ref in pool
89-
# release our hold on the underlying data
90-
release(ref.rc)
91-
92-
# early-release the reference
83+
# release the reference
84+
ref.cached = false
9385
unsafe_free!(ref)
9486
end
9587
empty!(cache.free)

test/testsuite/alloc_cache.jl

Lines changed: 39 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -5,75 +5,87 @@
55
# first allocation populates the cache
66
T, dims = Float32, (1, 2, 3)
77
GPUArrays.@cached cache begin
8-
x1 = AT(zeros(T, dims))
8+
cached1 = AT(zeros(T, dims))
99
end
10-
@test sizeof(cache) == sizeof(T) * prod(dims)
10+
@test sizeof(cache) == sizeof(cached1)
1111
key = first(keys(cache.free))
1212
@test length(cache.free[key]) == 1
1313
@test length(cache.busy[key]) == 0
14-
@test cache.free[key][1] === GPUArrays.storage(x1)
14+
@test cache.free[key][1] === GPUArrays.storage(cached1)
1515

1616
# second allocation hits the cache
1717
GPUArrays.@cached cache begin
18-
x2 = AT(zeros(T, dims))
18+
cached2 = AT(zeros(T, dims))
1919

2020
# explicitly uncached ones don't
21-
GPUArrays.@uncached x_free = AT(zeros(T, dims))
21+
GPUArrays.@uncached uncached = AT(zeros(T, dims))
2222
end
23-
@test sizeof(cache) == sizeof(T) * prod(dims)
23+
@test sizeof(cache) == sizeof(cached2)
2424
key = first(keys(cache.free))
2525
@test length(cache.free[key]) == 1
2626
@test length(cache.busy[key]) == 0
27-
@test cache.free[key][1] === GPUArrays.storage(x2)
28-
@test x_free !== x2
27+
@test cache.free[key][1] === GPUArrays.storage(cached2)
28+
@test uncached !== cached2
2929

3030
# compatible shapes should also hit the cache
3131
dims = (3, 2, 1)
3232
GPUArrays.@cached cache begin
33-
x3 = AT(zeros(T, dims))
33+
cached3 = AT(zeros(T, dims))
3434
end
35-
@test sizeof(cache) == sizeof(T) * prod(dims)
35+
@test sizeof(cache) == sizeof(cached3)
3636
key = first(keys(cache.free))
3737
@test length(cache.free[key]) == 1
3838
@test length(cache.busy[key]) == 0
39-
@test cache.free[key][1] === GPUArrays.storage(x3)
39+
@test cache.free[key][1] === GPUArrays.storage(cached3)
4040

4141
# as should compatible eltypes
4242
T = Int32
4343
GPUArrays.@cached cache begin
44-
x4 = AT(zeros(T, dims))
44+
cached4 = AT(zeros(T, dims))
4545
end
46-
@test sizeof(cache) == sizeof(T) * prod(dims)
46+
@test sizeof(cache) == sizeof(cached4)
4747
key = first(keys(cache.free))
4848
@test length(cache.free[key]) == 1
4949
@test length(cache.busy[key]) == 0
50-
@test cache.free[key][1] === GPUArrays.storage(x4)
50+
@test cache.free[key][1] === GPUArrays.storage(cached4)
5151

5252
# different shapes should trigger a new allocation
5353
dims = (2, 2)
5454
GPUArrays.@cached cache begin
55-
x5 = AT(zeros(T, dims))
55+
cached5 = AT(zeros(T, dims))
5656

57-
# we're allowed to early free arrays, which shouldn't release the underlying data
58-
GPUArrays.unsafe_free!(x5)
57+
# we're allowed to early free arrays, which should be a no-op for cached data
58+
GPUArrays.unsafe_free!(cached5)
5959
end
60+
@test sizeof(cache) == sizeof(cached4) + sizeof(cached5)
6061
_keys = collect(keys(cache.free))
6162
key2 = _keys[findfirst(i -> i != key, _keys)]
6263
@test length(cache.free[key]) == 1
6364
@test length(cache.free[key2]) == 1
64-
@test cache.free[key2][1] === GPUArrays.storage(x5)
65+
@test cache.free[key2][1] === GPUArrays.storage(cached5)
66+
67+
# we should be able to re-use the early-freed array
68+
GPUArrays.@cached cache begin
69+
cached5 = AT(zeros(T, dims))
70+
end
6571

6672
# freeing all memory held by cache should free all allocations
67-
@test !GPUArrays.storage(x1).freed
68-
@test GPUArrays.storage(x5).freed
69-
@test GPUArrays.storage(x5).rc.count[] == 1 # the ref appears freed, but the data isn't
70-
@test !GPUArrays.storage(x_free).freed
73+
@test !GPUArrays.storage(cached1).freed
74+
@test GPUArrays.storage(cached1).cached
75+
@test !GPUArrays.storage(cached5).freed
76+
@test GPUArrays.storage(cached5).cached
77+
@test !GPUArrays.storage(uncached).freed
78+
@test !GPUArrays.storage(uncached).cached
7179
GPUArrays.unsafe_free!(cache)
7280
@test sizeof(cache) == 0
73-
@test GPUArrays.storage(x1).freed
74-
@test GPUArrays.storage(x1).rc.count[] == 0
75-
@test GPUArrays.storage(x5).freed
76-
@test GPUArrays.storage(x5).rc.count[] == 0
77-
@test !GPUArrays.storage(x_free).freed
81+
@test GPUArrays.storage(cached1).freed
82+
@test !GPUArrays.storage(cached1).cached
83+
@test GPUArrays.storage(cached5).freed
84+
@test !GPUArrays.storage(cached5).cached
85+
@test !GPUArrays.storage(uncached).freed
86+
## test that the underlying data was freed as well
87+
@test GPUArrays.storage(cached1).rc.count[] == 0
88+
@test GPUArrays.storage(cached5).rc.count[] == 0
89+
@test GPUArrays.storage(uncached).rc.count[] == 1
7890
end
7991
end

0 commit comments

Comments
 (0)