diff --git a/src/PreallocationTools.jl b/src/PreallocationTools.jl index 4440b52..a1fbcb5 100644 --- a/src/PreallocationTools.jl +++ b/src/PreallocationTools.jl @@ -341,6 +341,61 @@ function Base.copy(glbc::GeneralLazyBufferCache) new_glbc end +# fill! dispatches for PreallocationTools types +""" + fill!(dc::DiffCache, val) + +Fill all allocated buffers in the DiffCache with the given value. +""" +function Base.fill!(dc::DiffCache, val) + fill!(dc.du, val) + fill!(dc.dual_du, val) + fill!(dc.any_du, nothing) + return dc +end + +""" + fill!(dc::FixedSizeDiffCache, val) + +Fill all allocated buffers in the FixedSizeDiffCache with the given value. +""" +function Base.fill!(dc::FixedSizeDiffCache, val) + fill!(dc.du, val) + fill!(dc.dual_du, val) + fill!(dc.any_du, nothing) + return dc +end + +""" + fill!(lbc::LazyBufferCache, val) + +Fill all allocated buffers in the LazyBufferCache with the given value. +""" +function Base.fill!(lbc::LazyBufferCache, val) + for (_, buffer) in lbc.bufs + if buffer isa AbstractArray + fill!(buffer, val) + end + end + return lbc +end + +""" + fill!(glbc::GeneralLazyBufferCache, val) + +Fill all allocated buffers in the GeneralLazyBufferCache with the given value. +""" +function Base.fill!(glbc::GeneralLazyBufferCache, val) + for (_, buffer) in glbc.bufs + if buffer isa AbstractArray + fill!(buffer, val) + elseif applicable(fill!, buffer, val) + fill!(buffer, val) + end + end + return glbc +end + export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache export get_tmp diff --git a/test/runtests.jl b/test/runtests.jl index c6e5e99..322eda5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,7 +19,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "DiffCache with SparseConnectivityTracer" include("sparse_connectivity_tracer.jl") @safetestset "LazyBufferCache" include("lbc.jl") @safetestset "GeneralLazyBufferCache" include("general_lbc.jl") - @safetestset "Zero and Copy Dispatches" include("test_zero_copy.jl") + @safetestset "Zero, Copy, and Fill Dispatches" include("test_zero_copy.jl") end if GROUP == "GPU" diff --git a/test/test_zero_copy.jl b/test/test_zero_copy.jl index 9e5d20a..4a1df39 100644 --- a/test/test_zero_copy.jl +++ b/test/test_zero_copy.jl @@ -112,4 +112,111 @@ using Test, PreallocationTools, ForwardDiff copy_cache.du[1, 1] = -999 @test cache.du[1, 1] != -999 end +end + +@testset "fill! dispatches" begin + @testset "DiffCache fill!" begin + u = rand(10) + cache = DiffCache(u, 5) + + # Fill with non-zero values initially + fill!(cache.du, 1.0) + fill!(cache.dual_du, 2.0) + push!(cache.any_du, 3.0) + + # Test fill! with 0 + fill!(cache, 0.0) + @test all(cache.du .== 0) + @test all(cache.dual_du .== 0) + @test all(cache.any_du .=== nothing) + + # Test fill! with other values + fill!(cache, 5.0) + @test all(cache.du .== 5.0) + @test all(cache.dual_du .== 5.0) + end + + @testset "FixedSizeDiffCache fill!" begin + u = rand(10) + cache = FixedSizeDiffCache(u, Val{5}) + + # Fill with non-zero values initially + fill!(cache.du, 1.0) + fill!(cache.dual_du, 2.0) + push!(cache.any_du, 3.0) + + # Test fill! with 0 + fill!(cache, 0.0) + @test all(cache.du .== 0) + @test all(cache.dual_du .== 0) + @test all(cache.any_du .=== nothing) + + # Test fill! with other values + fill!(cache, 3.0) + @test all(cache.du .== 3.0) + @test all(cache.dual_du .== 3.0) + end + + @testset "LazyBufferCache fill!" begin + lbc = LazyBufferCache(identity) + u = rand(10) + v = rand(5, 5) + + # Create and fill buffers + buf1 = lbc[u] + fill!(buf1, 1.0) + buf2 = lbc[v] + fill!(buf2, 2.0) + + # Test fill! with 0 + fill!(lbc, 0.0) + @test all(buf1 .== 0) + @test all(buf2 .== 0) + # Check that the buffers are still in the cache + @test lbc[u] === buf1 + @test lbc[v] === buf2 + + # Test fill! with other values + fill!(lbc, 7.0) + @test all(buf1 .== 7.0) + @test all(buf2 .== 7.0) + end + + @testset "GeneralLazyBufferCache fill!" begin + glbc = GeneralLazyBufferCache(u -> similar(u)) + u = rand(10) + + # Create and fill buffer + buf = glbc[u] + fill!(buf, 1.0) + + # Test fill! with 0 + fill!(glbc, 0.0) + @test all(buf .== 0) + # Check that the buffer is still in the cache + @test glbc[u] === buf + + # Test fill! with other values + fill!(glbc, -2.5) + @test all(buf .== -2.5) + end + + @testset "LazyBufferCache fill! with mixed types" begin + lbc = LazyBufferCache(identity) + u_float = rand(Float64, 10) + u_int = rand(Int, 5) + + # Create and fill buffers + buf_float = lbc[u_float] + fill!(buf_float, 1.5) + buf_int = lbc[u_int] + fill!(buf_int, 7) + + # Test fill! with 0 + fill!(lbc, 0) + @test all(buf_float .== 0.0) + @test all(buf_int .== 0) + @test eltype(buf_float) == Float64 + @test eltype(buf_int) == Int + end end \ No newline at end of file