diff --git a/src/PreallocationTools.jl b/src/PreallocationTools.jl index e46b416..6619f4d 100644 --- a/src/PreallocationTools.jl +++ b/src/PreallocationTools.jl @@ -302,6 +302,50 @@ function Base.resize!(dc::FixedSizeDiffCache, n::Integer) return dc end +# zero dispatches for PreallocationTools types +function Base.zero(dc::DiffCache) + DiffCache(zero(dc.du), zero(dc.dual_du), Any[]) +end + +function Base.zero(dc::FixedSizeDiffCache) + FixedSizeDiffCache(zero(dc.du), zero(dc.dual_du), Any[]) +end + +function Base.zero(lbc::LazyBufferCache) + LazyBufferCache(lbc.sizemap; initializer! = lbc.initializer!) +end + +function Base.zero(glbc::GeneralLazyBufferCache) + GeneralLazyBufferCache(glbc.f) +end + +# copy dispatches for PreallocationTools types +function Base.copy(dc::DiffCache) + DiffCache(copy(dc.du), copy(dc.dual_du), copy(dc.any_du)) +end + +function Base.copy(dc::FixedSizeDiffCache) + FixedSizeDiffCache(copy(dc.du), copy(dc.dual_du), copy(dc.any_du)) +end + +function Base.copy(lbc::LazyBufferCache) + new_lbc = LazyBufferCache(lbc.sizemap; initializer! = lbc.initializer!) + # Copy the internal buffer dictionary + for (key, val) in lbc.bufs + new_lbc.bufs[key] = copy(val) + end + new_lbc +end + +function Base.copy(glbc::GeneralLazyBufferCache) + new_glbc = GeneralLazyBufferCache(glbc.f) + # Copy the internal buffer dictionary + for (key, val) in glbc.bufs + new_glbc.bufs[key] = copy(val) + end + new_glbc +end + export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache export get_tmp diff --git a/test/runtests.jl b/test/runtests.jl index 0ae51df..c6e5e99 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,6 +19,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "DiffCache with SparseConnectivityTracer" include("sparse_connectivity_tracer.jl") @safetestset "LazyBufferCache" include("lbc.jl") @safetestset "GeneralLazyBufferCache" include("general_lbc.jl") + @safetestset "Zero and Copy Dispatches" include("test_zero_copy.jl") end if GROUP == "GPU" diff --git a/test/test_zero_copy.jl b/test/test_zero_copy.jl new file mode 100644 index 0000000..9e5d20a --- /dev/null +++ b/test/test_zero_copy.jl @@ -0,0 +1,115 @@ +using Test, PreallocationTools, ForwardDiff + +@testset "zero and copy dispatches" begin + @testset "DiffCache" begin + u = rand(10) + cache = DiffCache(u, 5) + + # Test zero + zero_cache = zero(cache) + @test isa(zero_cache, DiffCache) + @test all(zero_cache.du .== 0) + @test all(zero_cache.dual_du .== 0) + @test isempty(zero_cache.any_du) + + # Test copy + copy_cache = copy(cache) + @test isa(copy_cache, DiffCache) + @test copy_cache.du == cache.du + @test copy_cache.dual_du == cache.dual_du + @test copy_cache.any_du == cache.any_du + # Ensure it's a copy, not a reference + copy_cache.du[1] = -999 + @test cache.du[1] != -999 + end + + @testset "FixedSizeDiffCache" begin + u = rand(10) + cache = FixedSizeDiffCache(u, Val{5}) + + # Test zero + zero_cache = zero(cache) + @test isa(zero_cache, FixedSizeDiffCache) + @test all(zero_cache.du .== 0) + @test all(zero_cache.dual_du .== 0) + @test isempty(zero_cache.any_du) + + # Test copy + copy_cache = copy(cache) + @test isa(copy_cache, FixedSizeDiffCache) + @test copy_cache.du == cache.du + @test copy_cache.dual_du == cache.dual_du + @test copy_cache.any_du == cache.any_du + # Ensure it's a copy, not a reference + copy_cache.du[1] = -999 + @test cache.du[1] != -999 + end + + @testset "LazyBufferCache" begin + lbc = LazyBufferCache(identity; initializer! = buf -> fill!(buf, 0.0)) + u = rand(10) + buf = lbc[u] # Create a buffer in the cache + + # Test zero - creates a new empty cache with same configuration + zero_lbc = zero(lbc) + @test isa(zero_lbc, LazyBufferCache) + @test isempty(zero_lbc.bufs) + @test zero_lbc.sizemap === lbc.sizemap + @test zero_lbc.initializer! === lbc.initializer! + + # Test copy + copy_lbc = copy(lbc) + @test isa(copy_lbc, LazyBufferCache) + @test copy_lbc.sizemap === lbc.sizemap + @test copy_lbc.initializer! === lbc.initializer! + # Check that buffers were copied + @test !isempty(copy_lbc.bufs) + # Modify the copy to ensure it's independent + buf_copy = copy_lbc[u] + buf_copy[1] = -999 + @test buf[1] != -999 + end + + @testset "GeneralLazyBufferCache" begin + glbc = GeneralLazyBufferCache(u -> similar(u)) + u = rand(10) + buf = glbc[u] # Create a buffer in the cache + + # Test zero - creates a new empty cache with same function + zero_glbc = zero(glbc) + @test isa(zero_glbc, GeneralLazyBufferCache) + @test isempty(zero_glbc.bufs) + @test zero_glbc.f === glbc.f + + # Test copy + copy_glbc = copy(glbc) + @test isa(copy_glbc, GeneralLazyBufferCache) + @test copy_glbc.f === glbc.f + # Check that buffers were copied + @test !isempty(copy_glbc.bufs) + # Modify the copy to ensure it's independent + buf_copy = copy_glbc[u] + buf_copy[1] = -999 + @test buf[1] != -999 + end + + @testset "DiffCache with matrix" begin + u = rand(5, 5) + cache = DiffCache(u, 3) + + # Test zero + zero_cache = zero(cache) + @test isa(zero_cache, DiffCache) + @test size(zero_cache.du) == size(u) + @test all(zero_cache.du .== 0) + + # Test copy + copy_cache = copy(cache) + @test isa(copy_cache, DiffCache) + @test size(copy_cache.du) == size(u) + @test copy_cache.du == cache.du + # Ensure it's a copy, not a reference + copy_cache.du[1, 1] = -999 + @test cache.du[1, 1] != -999 + end +end \ No newline at end of file