Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/PreallocationTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,61 @@ function Base.copy(glbc::GeneralLazyBufferCache)
new_glbc
end

# fill! dispatches for PreallocationTools types
"""
fill!(dc::DiffCache, val)

Fill all allocated buffers in the DiffCache with the given value.
"""
function Base.fill!(dc::DiffCache, val)
fill!(dc.du, val)
fill!(dc.dual_du, val)
fill!(dc.any_du, nothing)
return dc
end

"""
fill!(dc::FixedSizeDiffCache, val)

Fill all allocated buffers in the FixedSizeDiffCache with the given value.
"""
function Base.fill!(dc::FixedSizeDiffCache, val)
fill!(dc.du, val)
fill!(dc.dual_du, val)
fill!(dc.any_du, nothing)
return dc
end

"""
fill!(lbc::LazyBufferCache, val)

Fill all allocated buffers in the LazyBufferCache with the given value.
"""
function Base.fill!(lbc::LazyBufferCache, val)
for (_, buffer) in lbc.bufs
if buffer isa AbstractArray
fill!(buffer, val)
end
end
return lbc
end

"""
fill!(glbc::GeneralLazyBufferCache, val)

Fill all allocated buffers in the GeneralLazyBufferCache with the given value.
"""
function Base.fill!(glbc::GeneralLazyBufferCache, val)
for (_, buffer) in glbc.bufs
if buffer isa AbstractArray
fill!(buffer, val)
elseif applicable(fill!, buffer, val)
fill!(buffer, val)
end
end
return glbc
end

export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache
export get_tmp

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "DiffCache with SparseConnectivityTracer" include("sparse_connectivity_tracer.jl")
@safetestset "LazyBufferCache" include("lbc.jl")
@safetestset "GeneralLazyBufferCache" include("general_lbc.jl")
@safetestset "Zero and Copy Dispatches" include("test_zero_copy.jl")
@safetestset "Zero, Copy, and Fill Dispatches" include("test_zero_copy.jl")
end

if GROUP == "GPU"
Expand Down
107 changes: 107 additions & 0 deletions test/test_zero_copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,111 @@ using Test, PreallocationTools, ForwardDiff
copy_cache.du[1, 1] = -999
@test cache.du[1, 1] != -999
end
end

@testset "fill! dispatches" begin
@testset "DiffCache fill!" begin
u = rand(10)
cache = DiffCache(u, 5)

# Fill with non-zero values initially
fill!(cache.du, 1.0)
fill!(cache.dual_du, 2.0)
push!(cache.any_du, 3.0)

# Test fill! with 0
fill!(cache, 0.0)
@test all(cache.du .== 0)
@test all(cache.dual_du .== 0)
@test all(cache.any_du .=== nothing)

# Test fill! with other values
fill!(cache, 5.0)
@test all(cache.du .== 5.0)
@test all(cache.dual_du .== 5.0)
end

@testset "FixedSizeDiffCache fill!" begin
u = rand(10)
cache = FixedSizeDiffCache(u, Val{5})

# Fill with non-zero values initially
fill!(cache.du, 1.0)
fill!(cache.dual_du, 2.0)
push!(cache.any_du, 3.0)

# Test fill! with 0
fill!(cache, 0.0)
@test all(cache.du .== 0)
@test all(cache.dual_du .== 0)
@test all(cache.any_du .=== nothing)

# Test fill! with other values
fill!(cache, 3.0)
@test all(cache.du .== 3.0)
@test all(cache.dual_du .== 3.0)
end

@testset "LazyBufferCache fill!" begin
lbc = LazyBufferCache(identity)
u = rand(10)
v = rand(5, 5)

# Create and fill buffers
buf1 = lbc[u]
fill!(buf1, 1.0)
buf2 = lbc[v]
fill!(buf2, 2.0)

# Test fill! with 0
fill!(lbc, 0.0)
@test all(buf1 .== 0)
@test all(buf2 .== 0)
# Check that the buffers are still in the cache
@test lbc[u] === buf1
@test lbc[v] === buf2

# Test fill! with other values
fill!(lbc, 7.0)
@test all(buf1 .== 7.0)
@test all(buf2 .== 7.0)
end

@testset "GeneralLazyBufferCache fill!" begin
glbc = GeneralLazyBufferCache(u -> similar(u))
u = rand(10)

# Create and fill buffer
buf = glbc[u]
fill!(buf, 1.0)

# Test fill! with 0
fill!(glbc, 0.0)
@test all(buf .== 0)
# Check that the buffer is still in the cache
@test glbc[u] === buf

# Test fill! with other values
fill!(glbc, -2.5)
@test all(buf .== -2.5)
end

@testset "LazyBufferCache fill! with mixed types" begin
lbc = LazyBufferCache(identity)
u_float = rand(Float64, 10)
u_int = rand(Int, 5)

# Create and fill buffers
buf_float = lbc[u_float]
fill!(buf_float, 1.5)
buf_int = lbc[u_int]
fill!(buf_int, 7)

# Test fill! with 0
fill!(lbc, 0)
@test all(buf_float .== 0.0)
@test all(buf_int .== 0)
@test eltype(buf_float) == Float64
@test eltype(buf_int) == Int
end
end
Loading