From 67606385b26547c95afa989db001a2a97c2abf76 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 10 Jan 2025 10:12:08 +0100 Subject: [PATCH 1/6] Add sizeof to DataRef. --- src/host/abstractarray.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/host/abstractarray.jl b/src/host/abstractarray.jl index 150cd88b6..0ef13044c 100644 --- a/src/host/abstractarray.jl +++ b/src/host/abstractarray.jl @@ -53,17 +53,19 @@ end # per-object state, with a flag to indicate whether the object has been freed. # this is to support multiple calls to `unsafe_free!` on the same object, -# while only lowering the referene count of the underlying data once. +# while only lowering the reference count of the underlying data once. mutable struct DataRef{D} rc::RefCounted{D} freed::Bool end -function DataRef(finalizer, data::D) where {D} - rc = RefCounted{D}(data, finalizer, Threads.Atomic{Int}(1)) +function DataRef(finalizer, ref::D) where {D} + rc = RefCounted{D}(ref, finalizer, Threads.Atomic{Int}(1)) DataRef{D}(rc, false) end -DataRef(data; kwargs...) = DataRef(nothing, data; kwargs...) +DataRef(ref; kwargs...) = DataRef(nothing, ref; kwargs...) + +Base.sizeof(ref::DataRef) = sizeof(ref.rc[]) function Base.getindex(ref::DataRef) if ref.freed From 319551e7c519bd1fc3762a98f0ac5c0a49eaf099 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 10 Jan 2025 10:12:23 +0100 Subject: [PATCH 2/6] Remove unused multi-arg unsafe_free. --- src/host/abstractarray.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/host/abstractarray.jl b/src/host/abstractarray.jl index 0ef13044c..0806b61ea 100644 --- a/src/host/abstractarray.jl +++ b/src/host/abstractarray.jl @@ -82,7 +82,7 @@ function Base.copy(ref::DataRef{D}) where {D} return DataRef{D}(ref.rc, false) end -function unsafe_free!(ref::DataRef, args...) +function unsafe_free!(ref::DataRef) if ref.freed # multiple frees *of the same object* are allowed. # we should only ever call `release` once per object, though, @@ -90,7 +90,7 @@ function unsafe_free!(ref::DataRef, args...) return end ref.freed = true - release(ref.rc, args...) + release(ref.rc) return end From 98ad07dc5a46f95c783882fd40be3605639922b2 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 10 Jan 2025 10:13:00 +0100 Subject: [PATCH 3/6] Cache DataRefs instead of arrays. --- lib/JLArrays/src/JLArrays.jl | 13 +++++---- src/host/alloc_cache.jl | 45 +++++++++++++++++------------- test/testsuite/alloc_cache.jl | 52 +++++++++++++++++++++++++++++------ 3 files changed, 77 insertions(+), 33 deletions(-) diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index d36e9af2a..6e5d38df7 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -89,15 +89,16 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N} check_eltype(T) maxsize = prod(dims) * sizeof(T) - return GPUArrays.cached_alloc((JLArray, T, dims)) do + ref = GPUArrays.cached_alloc((JLArray, maxsize)) do data = Vector{UInt8}(undef, maxsize) - ref = DataRef(data) do data + DataRef(data) do data resize!(data, 0) end - obj = new{T, N}(ref, 0, dims) - finalizer(unsafe_free!, obj) - return obj - end::JLArray{T, N} + end + + obj = new{T, N}(ref, 0, dims) + finalizer(unsafe_free!, obj) + return obj end # low-level constructor for wrapping existing data diff --git a/src/host/alloc_cache.jl b/src/host/alloc_cache.jl index 22775b2dd..dd39cfc90 100644 --- a/src/host/alloc_cache.jl +++ b/src/host/alloc_cache.jl @@ -8,8 +8,8 @@ end mutable struct AllocCache lock::ReentrantLock - busy::Dict{UInt64, Vector{Any}} # hash(key) => GPUArray[] - free::Dict{UInt64, Vector{Any}} + busy::Dict{UInt64, Vector{DataRef}} + free::Dict{UInt64, Vector{DataRef}} function AllocCache() cache = new( @@ -24,8 +24,8 @@ end function get_pool!(cache::AllocCache, pool::Symbol, uid::UInt64) pool = getproperty(cache, pool) uid_pool = get(pool, uid, nothing) - if uid_pool ≡ nothing - uid_pool = Base.@lock cache.lock pool[uid] = Any[] + if uid_pool === nothing + uid_pool = Base.@lock cache.lock pool[uid] = DataRef[] end return uid_pool end @@ -33,30 +33,33 @@ end function cached_alloc(f, key) cache = ALLOC_CACHE[] if cache === nothing - return f()::AbstractGPUArray + return f()::DataRef end - x = nothing + ref = nothing uid = hash(key) busy_pool = get_pool!(cache, :busy, uid) free_pool = get_pool!(cache, :free, uid) - isempty(free_pool) && (x = f()::AbstractGPUArray) - while !isempty(free_pool) && x ≡ nothing - tmp = Base.@lock cache.lock pop!(free_pool) - # Array was manually freed via `unsafe_free!`. - GPUArrays.storage(tmp).freed && continue - x = tmp + if !isempty(free_pool) + ref = Base.@lock cache.lock pop!(free_pool) + @assert !ref.freed end - x ≡ nothing && (x = f()::AbstractGPUArray) - Base.@lock cache.lock push!(busy_pool, x) - return x + if ref === nothing + ref = f()::DataRef + + # increase the refcount of the ref to prevent finalizers from freeing it + retain(ref.rc) + end + + Base.@lock cache.lock push!(busy_pool, ref) + return ref end function free_busy!(cache::AllocCache) - for uid in cache.busy.keys + for uid in keys(cache.busy) busy_pool = get_pool!(cache, :busy, uid) isempty(busy_pool) && continue @@ -71,14 +74,18 @@ end function unsafe_free!(cache::AllocCache) Base.@lock cache.lock begin - for (_, pool) in cache.busy + for pool in values(cache.busy) isempty(pool) || error( "Invalidating allocations cache that's currently in use. " * "Invalidating inside `@cached` is not allowed." ) end - for (_, pool) in cache.free - map(unsafe_free!, pool) + for pool in values(cache.free), ref in pool + # release our hold on the underlying data + release(ref.rc) + + # early-release the reference + unsafe_free!(ref) end empty!(cache.free) end diff --git a/test/testsuite/alloc_cache.jl b/test/testsuite/alloc_cache.jl index b032c8bda..e63d99e04 100644 --- a/test/testsuite/alloc_cache.jl +++ b/test/testsuite/alloc_cache.jl @@ -2,6 +2,7 @@ if AT <: AbstractGPUArray cache = GPUArrays.AllocCache() + # first allocation populates the cache T, dims = Float32, (1, 2, 3) GPUArrays.@cached cache begin x1 = AT(zeros(T, dims)) @@ -10,34 +11,69 @@ key = first(keys(cache.free)) @test length(cache.free[key]) == 1 @test length(cache.busy[key]) == 0 - @test x1 === cache.free[key][1] + @test cache.free[key][1] === GPUArrays.storage(x1) - # Second allocation hits cache. + # second allocation hits the cache GPUArrays.@cached cache begin x2 = AT(zeros(T, dims)) - # Does not hit the cache. + + # explicitly uncached ones don't GPUArrays.@uncached x_free = AT(zeros(T, dims)) end @test sizeof(cache) == sizeof(T) * prod(dims) key = first(keys(cache.free)) @test length(cache.free[key]) == 1 @test length(cache.busy[key]) == 0 - @test x2 === cache.free[key][1] + @test cache.free[key][1] === GPUArrays.storage(x2) @test x_free !== x2 - # Third allocation is of different shape - allocates. - dims = (2, 2) + # compatible shapes should also hit the cache + dims = (3, 2, 1) GPUArrays.@cached cache begin x3 = AT(zeros(T, dims)) end + @test sizeof(cache) == sizeof(T) * prod(dims) + key = first(keys(cache.free)) + @test length(cache.free[key]) == 1 + @test length(cache.busy[key]) == 0 + @test cache.free[key][1] === GPUArrays.storage(x3) + + # as should compatible eltypes + T = Int32 + GPUArrays.@cached cache begin + x4 = AT(zeros(T, dims)) + end + @test sizeof(cache) == sizeof(T) * prod(dims) + key = first(keys(cache.free)) + @test length(cache.free[key]) == 1 + @test length(cache.busy[key]) == 0 + @test cache.free[key][1] === GPUArrays.storage(x4) + + # different shapes should trigger a new allocation + dims = (2, 2) + GPUArrays.@cached cache begin + x5 = AT(zeros(T, dims)) + + # we're allowed to early free arrays, which shouldn't release the underlying data + GPUArrays.unsafe_free!(x5) + end _keys = collect(keys(cache.free)) key2 = _keys[findfirst(i -> i != key, _keys)] @test length(cache.free[key]) == 1 @test length(cache.free[key2]) == 1 - @test x3 === cache.free[key2][1] + @test cache.free[key2][1] === GPUArrays.storage(x5) - # Freeing all memory held by cache. + # freeing all memory held by cache should free all allocations + @test !GPUArrays.storage(x1).freed + @test GPUArrays.storage(x5).freed + @test GPUArrays.storage(x5).rc.count[] == 1 # the ref appears freed, but the data isn't + @test !GPUArrays.storage(x_free).freed GPUArrays.unsafe_free!(cache) @test sizeof(cache) == 0 + @test GPUArrays.storage(x1).freed + @test GPUArrays.storage(x1).rc.count[] == 0 + @test GPUArrays.storage(x5).freed + @test GPUArrays.storage(x5).rc.count[] == 0 + @test !GPUArrays.storage(x_free).freed end end From 41a745710948ac9ab2aec1654fbd73ebe50181d0 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 10 Jan 2025 10:20:55 +0100 Subject: [PATCH 4/6] Lock across more cache operations. It's not safe to concurrently access a dict while it may be mutated. --- src/host/alloc_cache.jl | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/host/alloc_cache.jl b/src/host/alloc_cache.jl index dd39cfc90..c764bea49 100644 --- a/src/host/alloc_cache.jl +++ b/src/host/alloc_cache.jl @@ -25,7 +25,7 @@ function get_pool!(cache::AllocCache, pool::Symbol, uid::UInt64) pool = getproperty(cache, pool) uid_pool = get(pool, uid, nothing) if uid_pool === nothing - uid_pool = Base.@lock cache.lock pool[uid] = DataRef[] + uid_pool = pool[uid] = DataRef[] end return uid_pool end @@ -39,12 +39,13 @@ function cached_alloc(f, key) ref = nothing uid = hash(key) - busy_pool = get_pool!(cache, :busy, uid) - free_pool = get_pool!(cache, :free, uid) + Base.@lock cache.lock begin + free_pool = get_pool!(cache, :free, uid) - if !isempty(free_pool) - ref = Base.@lock cache.lock pop!(free_pool) - @assert !ref.freed + if !isempty(free_pool) + ref = Base.@lock cache.lock pop!(free_pool) + @assert !ref.freed + end end if ref === nothing @@ -54,16 +55,20 @@ function cached_alloc(f, key) retain(ref.rc) end - Base.@lock cache.lock push!(busy_pool, ref) + Base.@lock cache.lock begin + busy_pool = get_pool!(cache, :busy, uid) + push!(busy_pool, ref) + end + return ref end function free_busy!(cache::AllocCache) - for uid in keys(cache.busy) - busy_pool = get_pool!(cache, :busy, uid) - isempty(busy_pool) && continue + Base.@lock cache.lock begin + for uid in keys(cache.busy) + busy_pool = get_pool!(cache, :busy, uid) + isempty(busy_pool) && continue - Base.@lock cache.lock begin free_pool = get_pool!(cache, :free, uid) append!(free_pool, busy_pool) empty!(busy_pool) From 07571f8c56e6113ba2565c9401cd609ed495aee2 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 13 Jan 2025 18:13:34 +0100 Subject: [PATCH 5/6] Teach DataRef about cached references to support unsafe_free. --- src/host/abstractarray.jl | 11 ++++-- src/host/alloc_cache.jl | 16 +++------ test/testsuite/alloc_cache.jl | 66 +++++++++++++++++++++-------------- 3 files changed, 52 insertions(+), 41 deletions(-) diff --git a/src/host/abstractarray.jl b/src/host/abstractarray.jl index 0806b61ea..881e6e88c 100644 --- a/src/host/abstractarray.jl +++ b/src/host/abstractarray.jl @@ -57,11 +57,12 @@ end mutable struct DataRef{D} rc::RefCounted{D} freed::Bool + cached::Bool end function DataRef(finalizer, ref::D) where {D} rc = RefCounted{D}(ref, finalizer, Threads.Atomic{Int}(1)) - DataRef{D}(rc, false) + DataRef{D}(rc, false, false) end DataRef(ref; kwargs...) = DataRef(nothing, ref; kwargs...) @@ -79,10 +80,16 @@ function Base.copy(ref::DataRef{D}) where {D} throw(ArgumentError("Attempt to copy a freed reference.")) end retain(ref.rc) - return DataRef{D}(ref.rc, false) + # copies of cached references are not managed by the cache, so + # we need to mark them as such to make sure their refcount can drop. + return DataRef{D}(ref.rc, false, false) end function unsafe_free!(ref::DataRef) + if ref.cached + # lifetimes of cached references are tied to the cache. + return + end if ref.freed # multiple frees *of the same object* are allowed. # we should only ever call `release` once per object, though, diff --git a/src/host/alloc_cache.jl b/src/host/alloc_cache.jl index c764bea49..a93722472 100644 --- a/src/host/alloc_cache.jl +++ b/src/host/alloc_cache.jl @@ -44,15 +44,12 @@ function cached_alloc(f, key) if !isempty(free_pool) ref = Base.@lock cache.lock pop!(free_pool) - @assert !ref.freed end end if ref === nothing ref = f()::DataRef - - # increase the refcount of the ref to prevent finalizers from freeing it - retain(ref.rc) + ref.cached = true end Base.@lock cache.lock begin @@ -80,16 +77,11 @@ end function unsafe_free!(cache::AllocCache) Base.@lock cache.lock begin for pool in values(cache.busy) - isempty(pool) || error( - "Invalidating allocations cache that's currently in use. " * - "Invalidating inside `@cached` is not allowed." - ) + isempty(pool) || error("Cannot invalidate a cache that's in active use") end for pool in values(cache.free), ref in pool - # release our hold on the underlying data - release(ref.rc) - - # early-release the reference + # release the reference + ref.cached = false unsafe_free!(ref) end empty!(cache.free) diff --git a/test/testsuite/alloc_cache.jl b/test/testsuite/alloc_cache.jl index e63d99e04..6eae52243 100644 --- a/test/testsuite/alloc_cache.jl +++ b/test/testsuite/alloc_cache.jl @@ -5,75 +5,87 @@ # first allocation populates the cache T, dims = Float32, (1, 2, 3) GPUArrays.@cached cache begin - x1 = AT(zeros(T, dims)) + cached1 = AT(zeros(T, dims)) end - @test sizeof(cache) == sizeof(T) * prod(dims) + @test sizeof(cache) == sizeof(cached1) key = first(keys(cache.free)) @test length(cache.free[key]) == 1 @test length(cache.busy[key]) == 0 - @test cache.free[key][1] === GPUArrays.storage(x1) + @test cache.free[key][1] === GPUArrays.storage(cached1) # second allocation hits the cache GPUArrays.@cached cache begin - x2 = AT(zeros(T, dims)) + cached2 = AT(zeros(T, dims)) # explicitly uncached ones don't - GPUArrays.@uncached x_free = AT(zeros(T, dims)) + GPUArrays.@uncached uncached = AT(zeros(T, dims)) end - @test sizeof(cache) == sizeof(T) * prod(dims) + @test sizeof(cache) == sizeof(cached2) key = first(keys(cache.free)) @test length(cache.free[key]) == 1 @test length(cache.busy[key]) == 0 - @test cache.free[key][1] === GPUArrays.storage(x2) - @test x_free !== x2 + @test cache.free[key][1] === GPUArrays.storage(cached2) + @test uncached !== cached2 # compatible shapes should also hit the cache dims = (3, 2, 1) GPUArrays.@cached cache begin - x3 = AT(zeros(T, dims)) + cached3 = AT(zeros(T, dims)) end - @test sizeof(cache) == sizeof(T) * prod(dims) + @test sizeof(cache) == sizeof(cached3) key = first(keys(cache.free)) @test length(cache.free[key]) == 1 @test length(cache.busy[key]) == 0 - @test cache.free[key][1] === GPUArrays.storage(x3) + @test cache.free[key][1] === GPUArrays.storage(cached3) # as should compatible eltypes T = Int32 GPUArrays.@cached cache begin - x4 = AT(zeros(T, dims)) + cached4 = AT(zeros(T, dims)) end - @test sizeof(cache) == sizeof(T) * prod(dims) + @test sizeof(cache) == sizeof(cached4) key = first(keys(cache.free)) @test length(cache.free[key]) == 1 @test length(cache.busy[key]) == 0 - @test cache.free[key][1] === GPUArrays.storage(x4) + @test cache.free[key][1] === GPUArrays.storage(cached4) # different shapes should trigger a new allocation dims = (2, 2) GPUArrays.@cached cache begin - x5 = AT(zeros(T, dims)) + cached5 = AT(zeros(T, dims)) - # we're allowed to early free arrays, which shouldn't release the underlying data - GPUArrays.unsafe_free!(x5) + # we're allowed to early free arrays, which should be a no-op for cached data + GPUArrays.unsafe_free!(cached5) end + @test sizeof(cache) == sizeof(cached4) + sizeof(cached5) _keys = collect(keys(cache.free)) key2 = _keys[findfirst(i -> i != key, _keys)] @test length(cache.free[key]) == 1 @test length(cache.free[key2]) == 1 - @test cache.free[key2][1] === GPUArrays.storage(x5) + @test cache.free[key2][1] === GPUArrays.storage(cached5) + + # we should be able to re-use the early-freed + GPUArrays.@cached cache begin + cached5 = AT(zeros(T, dims)) + end # freeing all memory held by cache should free all allocations - @test !GPUArrays.storage(x1).freed - @test GPUArrays.storage(x5).freed - @test GPUArrays.storage(x5).rc.count[] == 1 # the ref appears freed, but the data isn't - @test !GPUArrays.storage(x_free).freed + @test !GPUArrays.storage(cached1).freed + @test GPUArrays.storage(cached1).cached + @test !GPUArrays.storage(cached5).freed + @test GPUArrays.storage(cached5).cached + @test !GPUArrays.storage(uncached).freed + @test !GPUArrays.storage(uncached).cached GPUArrays.unsafe_free!(cache) @test sizeof(cache) == 0 - @test GPUArrays.storage(x1).freed - @test GPUArrays.storage(x1).rc.count[] == 0 - @test GPUArrays.storage(x5).freed - @test GPUArrays.storage(x5).rc.count[] == 0 - @test !GPUArrays.storage(x_free).freed + @test GPUArrays.storage(cached1).freed + @test !GPUArrays.storage(cached1).cached + @test GPUArrays.storage(cached5).freed + @test !GPUArrays.storage(cached5).cached + @test !GPUArrays.storage(uncached).freed + ## test that the underlying data was freed as well + @test GPUArrays.storage(cached1).rc.count[] == 0 + @test GPUArrays.storage(cached5).rc.count[] == 0 + @test GPUArrays.storage(uncached).rc.count[] == 1 end end From 8a23d20865cc0a4449bf4e1663e55882f4be3e94 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 15 Jan 2025 09:23:13 +0100 Subject: [PATCH 6/6] Free the cache from the at-cached finalizer. --- src/host/alloc_cache.jl | 10 ++++------ test/testsuite/alloc_cache.jl | 8 ++++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/host/alloc_cache.jl b/src/host/alloc_cache.jl index a93722472..442f375bb 100644 --- a/src/host/alloc_cache.jl +++ b/src/host/alloc_cache.jl @@ -147,13 +147,11 @@ GPUArrays.unsafe_free!(cache) See [`@uncached`](@ref). """ macro cached(cache, expr) + try_expr = :(@with $(esc(ALLOC_CACHE)) => cache $(esc(expr))) + fin_expr = :(free_busy!($(esc(cache)))) return quote - cache = $(esc(cache)) - GC.@preserve cache begin - res = @with $(esc(ALLOC_CACHE)) => cache $(esc(expr)) - free_busy!(cache) - res - end + local cache = $(esc(cache)) + GC.@preserve cache $(Expr(:tryfinally, try_expr, fin_expr)) end end diff --git a/test/testsuite/alloc_cache.jl b/test/testsuite/alloc_cache.jl index 6eae52243..e63ca6c2c 100644 --- a/test/testsuite/alloc_cache.jl +++ b/test/testsuite/alloc_cache.jl @@ -69,6 +69,14 @@ cached5 = AT(zeros(T, dims)) end + # exceptions shouldn't cause issues + @test_throws "Allowed exception" GPUArrays.@cached cache begin + AT(zeros(T, dims)) + error("Allowed exception") + end + # NOTE: this should remaint the last test before calling `unsafe_free!` below, + # as it caught an erroneous assertion in the original code. + # freeing all memory held by cache should free all allocations @test !GPUArrays.storage(cached1).freed @test GPUArrays.storage(cached1).cached