Skip to content

Commit dfc60b3

Browse files
Merge pull request #142 from ChrisRackauckas-Claude/add-zero-function
Add fill! overloads for cache types
2 parents deb294a + 07a5510 commit dfc60b3

File tree

3 files changed

+163
-1
lines changed

3 files changed

+163
-1
lines changed

src/PreallocationTools.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,61 @@ function Base.copy(glbc::GeneralLazyBufferCache)
341341
new_glbc
342342
end
343343

344+
# fill! dispatches for PreallocationTools types
345+
"""
346+
fill!(dc::DiffCache, val)
347+
348+
Fill all allocated buffers in the DiffCache with the given value.
349+
"""
350+
function Base.fill!(dc::DiffCache, val)
351+
fill!(dc.du, val)
352+
fill!(dc.dual_du, val)
353+
fill!(dc.any_du, nothing)
354+
return dc
355+
end
356+
357+
"""
358+
fill!(dc::FixedSizeDiffCache, val)
359+
360+
Fill all allocated buffers in the FixedSizeDiffCache with the given value.
361+
"""
362+
function Base.fill!(dc::FixedSizeDiffCache, val)
363+
fill!(dc.du, val)
364+
fill!(dc.dual_du, val)
365+
fill!(dc.any_du, nothing)
366+
return dc
367+
end
368+
369+
"""
370+
fill!(lbc::LazyBufferCache, val)
371+
372+
Fill all allocated buffers in the LazyBufferCache with the given value.
373+
"""
374+
function Base.fill!(lbc::LazyBufferCache, val)
375+
for (_, buffer) in lbc.bufs
376+
if buffer isa AbstractArray
377+
fill!(buffer, val)
378+
end
379+
end
380+
return lbc
381+
end
382+
383+
"""
384+
fill!(glbc::GeneralLazyBufferCache, val)
385+
386+
Fill all allocated buffers in the GeneralLazyBufferCache with the given value.
387+
"""
388+
function Base.fill!(glbc::GeneralLazyBufferCache, val)
389+
for (_, buffer) in glbc.bufs
390+
if buffer isa AbstractArray
391+
fill!(buffer, val)
392+
elseif applicable(fill!, buffer, val)
393+
fill!(buffer, val)
394+
end
395+
end
396+
return glbc
397+
end
398+
344399
export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache
345400
export get_tmp
346401

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ if GROUP == "All" || GROUP == "Core"
1919
@safetestset "DiffCache with SparseConnectivityTracer" include("sparse_connectivity_tracer.jl")
2020
@safetestset "LazyBufferCache" include("lbc.jl")
2121
@safetestset "GeneralLazyBufferCache" include("general_lbc.jl")
22-
@safetestset "Zero and Copy Dispatches" include("test_zero_copy.jl")
22+
@safetestset "Zero, Copy, and Fill Dispatches" include("test_zero_copy.jl")
2323
end
2424

2525
if GROUP == "GPU"

test/test_zero_copy.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,111 @@ using Test, PreallocationTools, ForwardDiff
112112
copy_cache.du[1, 1] = -999
113113
@test cache.du[1, 1] != -999
114114
end
115+
end
116+
117+
@testset "fill! dispatches" begin
118+
@testset "DiffCache fill!" begin
119+
u = rand(10)
120+
cache = DiffCache(u, 5)
121+
122+
# Fill with non-zero values initially
123+
fill!(cache.du, 1.0)
124+
fill!(cache.dual_du, 2.0)
125+
push!(cache.any_du, 3.0)
126+
127+
# Test fill! with 0
128+
fill!(cache, 0.0)
129+
@test all(cache.du .== 0)
130+
@test all(cache.dual_du .== 0)
131+
@test all(cache.any_du .=== nothing)
132+
133+
# Test fill! with other values
134+
fill!(cache, 5.0)
135+
@test all(cache.du .== 5.0)
136+
@test all(cache.dual_du .== 5.0)
137+
end
138+
139+
@testset "FixedSizeDiffCache fill!" begin
140+
u = rand(10)
141+
cache = FixedSizeDiffCache(u, Val{5})
142+
143+
# Fill with non-zero values initially
144+
fill!(cache.du, 1.0)
145+
fill!(cache.dual_du, 2.0)
146+
push!(cache.any_du, 3.0)
147+
148+
# Test fill! with 0
149+
fill!(cache, 0.0)
150+
@test all(cache.du .== 0)
151+
@test all(cache.dual_du .== 0)
152+
@test all(cache.any_du .=== nothing)
153+
154+
# Test fill! with other values
155+
fill!(cache, 3.0)
156+
@test all(cache.du .== 3.0)
157+
@test all(cache.dual_du .== 3.0)
158+
end
159+
160+
@testset "LazyBufferCache fill!" begin
161+
lbc = LazyBufferCache(identity)
162+
u = rand(10)
163+
v = rand(5, 5)
164+
165+
# Create and fill buffers
166+
buf1 = lbc[u]
167+
fill!(buf1, 1.0)
168+
buf2 = lbc[v]
169+
fill!(buf2, 2.0)
170+
171+
# Test fill! with 0
172+
fill!(lbc, 0.0)
173+
@test all(buf1 .== 0)
174+
@test all(buf2 .== 0)
175+
# Check that the buffers are still in the cache
176+
@test lbc[u] === buf1
177+
@test lbc[v] === buf2
178+
179+
# Test fill! with other values
180+
fill!(lbc, 7.0)
181+
@test all(buf1 .== 7.0)
182+
@test all(buf2 .== 7.0)
183+
end
184+
185+
@testset "GeneralLazyBufferCache fill!" begin
186+
glbc = GeneralLazyBufferCache(u -> similar(u))
187+
u = rand(10)
188+
189+
# Create and fill buffer
190+
buf = glbc[u]
191+
fill!(buf, 1.0)
192+
193+
# Test fill! with 0
194+
fill!(glbc, 0.0)
195+
@test all(buf .== 0)
196+
# Check that the buffer is still in the cache
197+
@test glbc[u] === buf
198+
199+
# Test fill! with other values
200+
fill!(glbc, -2.5)
201+
@test all(buf .== -2.5)
202+
end
203+
204+
@testset "LazyBufferCache fill! with mixed types" begin
205+
lbc = LazyBufferCache(identity)
206+
u_float = rand(Float64, 10)
207+
u_int = rand(Int, 5)
208+
209+
# Create and fill buffers
210+
buf_float = lbc[u_float]
211+
fill!(buf_float, 1.5)
212+
buf_int = lbc[u_int]
213+
fill!(buf_int, 7)
214+
215+
# Test fill! with 0
216+
fill!(lbc, 0)
217+
@test all(buf_float .== 0.0)
218+
@test all(buf_int .== 0)
219+
@test eltype(buf_float) == Float64
220+
@test eltype(buf_int) == Int
221+
end
115222
end

0 commit comments

Comments
 (0)