Skip to content

Commit 1c84df5

Browse files
Add fill! overloads for cache types
This adds fill! overloads for DiffCache, FixedSizeDiffCache, LazyBufferCache, and GeneralLazyBufferCache that fill all allocated buffers with a given value. This is needed to ensure proper initialization of caches in split ODE methods to avoid convergence issues. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent deb294a commit 1c84df5

File tree

3 files changed

+164
-1
lines changed

3 files changed

+164
-1
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ version = "0.4.33"
66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
910
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
11+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1012

1113
[weakdeps]
12-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1314
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1415
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
1516

src/PreallocationTools.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,61 @@ function Base.copy(glbc::GeneralLazyBufferCache)
341341
new_glbc
342342
end
343343

344+
# fill! dispatches for PreallocationTools types
345+
"""
346+
fill!(dc::DiffCache, val)
347+
348+
Fill all allocated buffers in the DiffCache with the given value.
349+
"""
350+
function Base.fill!(dc::DiffCache, val)
351+
fill!(dc.du, val)
352+
fill!(dc.dual_du, val)
353+
fill!(dc.any_du, nothing)
354+
return dc
355+
end
356+
357+
"""
358+
fill!(dc::FixedSizeDiffCache, val)
359+
360+
Fill all allocated buffers in the FixedSizeDiffCache with the given value.
361+
"""
362+
function Base.fill!(dc::FixedSizeDiffCache, val)
363+
fill!(dc.du, val)
364+
fill!(dc.dual_du, val)
365+
fill!(dc.any_du, nothing)
366+
return dc
367+
end
368+
369+
"""
370+
fill!(lbc::LazyBufferCache, val)
371+
372+
Fill all allocated buffers in the LazyBufferCache with the given value.
373+
"""
374+
function Base.fill!(lbc::LazyBufferCache, val)
375+
for (_, buffer) in lbc.bufs
376+
if buffer isa AbstractArray
377+
fill!(buffer, val)
378+
end
379+
end
380+
return lbc
381+
end
382+
383+
"""
384+
fill!(glbc::GeneralLazyBufferCache, val)
385+
386+
Fill all allocated buffers in the GeneralLazyBufferCache with the given value.
387+
"""
388+
function Base.fill!(glbc::GeneralLazyBufferCache, val)
389+
for (_, buffer) in glbc.bufs
390+
if buffer isa AbstractArray
391+
fill!(buffer, val)
392+
elseif applicable(fill!, buffer, val)
393+
fill!(buffer, val)
394+
end
395+
end
396+
return glbc
397+
end
398+
344399
export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache
345400
export get_tmp
346401

test/test_zero_copy.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,111 @@ using Test, PreallocationTools, ForwardDiff
112112
copy_cache.du[1, 1] = -999
113113
@test cache.du[1, 1] != -999
114114
end
115+
end
116+
117+
@testset "fill! dispatches" begin
118+
@testset "DiffCache fill!" begin
119+
u = rand(10)
120+
cache = DiffCache(u, 5)
121+
122+
# Fill with non-zero values initially
123+
fill!(cache.du, 1.0)
124+
fill!(cache.dual_du, 2.0)
125+
push!(cache.any_du, 3.0)
126+
127+
# Test fill! with 0
128+
fill!(cache, 0.0)
129+
@test all(cache.du .== 0)
130+
@test all(cache.dual_du .== 0)
131+
@test all(cache.any_du .=== nothing)
132+
133+
# Test fill! with other values
134+
fill!(cache, 5.0)
135+
@test all(cache.du .== 5.0)
136+
@test all(cache.dual_du .== 5.0)
137+
end
138+
139+
@testset "FixedSizeDiffCache fill!" begin
140+
u = rand(10)
141+
cache = FixedSizeDiffCache(u, Val{5})
142+
143+
# Fill with non-zero values initially
144+
fill!(cache.du, 1.0)
145+
fill!(cache.dual_du, 2.0)
146+
push!(cache.any_du, 3.0)
147+
148+
# Test fill! with 0
149+
fill!(cache, 0.0)
150+
@test all(cache.du .== 0)
151+
@test all(cache.dual_du .== 0)
152+
@test all(cache.any_du .=== nothing)
153+
154+
# Test fill! with other values
155+
fill!(cache, 3.0)
156+
@test all(cache.du .== 3.0)
157+
@test all(cache.dual_du .== 3.0)
158+
end
159+
160+
@testset "LazyBufferCache fill!" begin
161+
lbc = LazyBufferCache(identity)
162+
u = rand(10)
163+
v = rand(5, 5)
164+
165+
# Create and fill buffers
166+
buf1 = lbc[u]
167+
fill!(buf1, 1.0)
168+
buf2 = lbc[v]
169+
fill!(buf2, 2.0)
170+
171+
# Test fill! with 0
172+
fill!(lbc, 0.0)
173+
@test all(buf1 .== 0)
174+
@test all(buf2 .== 0)
175+
# Check that the buffers are still in the cache
176+
@test lbc[u] === buf1
177+
@test lbc[v] === buf2
178+
179+
# Test fill! with other values
180+
fill!(lbc, 7.0)
181+
@test all(buf1 .== 7.0)
182+
@test all(buf2 .== 7.0)
183+
end
184+
185+
@testset "GeneralLazyBufferCache fill!" begin
186+
glbc = GeneralLazyBufferCache(u -> similar(u))
187+
u = rand(10)
188+
189+
# Create and fill buffer
190+
buf = glbc[u]
191+
fill!(buf, 1.0)
192+
193+
# Test fill! with 0
194+
fill!(glbc, 0.0)
195+
@test all(buf .== 0)
196+
# Check that the buffer is still in the cache
197+
@test glbc[u] === buf
198+
199+
# Test fill! with other values
200+
fill!(glbc, -2.5)
201+
@test all(buf .== -2.5)
202+
end
203+
204+
@testset "LazyBufferCache fill! with mixed types" begin
205+
lbc = LazyBufferCache(identity)
206+
u_float = rand(Float64, 10)
207+
u_int = rand(Int, 5)
208+
209+
# Create and fill buffers
210+
buf_float = lbc[u_float]
211+
fill!(buf_float, 1.5)
212+
buf_int = lbc[u_int]
213+
fill!(buf_int, 7)
214+
215+
# Test fill! with 0
216+
fill!(lbc, 0)
217+
@test all(buf_float .== 0.0)
218+
@test all(buf_int .== 0)
219+
@test eltype(buf_float) == Float64
220+
@test eltype(buf_int) == Int
221+
end
115222
end

0 commit comments

Comments
 (0)