Skip to content

Commit 6ab280e

Browse files
Add zero and copy dispatches for PreallocationTools types
This commit adds Base.zero and Base.copy methods for all cache types: - DiffCache - FixedSizeDiffCache - LazyBufferCache - GeneralLazyBufferCache Key changes: - Modified struct definitions to allow dual_du to be Nothing - Added zero methods that create properly zeroed structures - Added copy methods that create independent deep copies - Added comprehensive tests for all functionality - Handles cases where dual_du might be nothing (when ForwardDiff isn't loaded) 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 07e3a76 commit 6ab280e

File tree

3 files changed

+170
-2
lines changed

3 files changed

+170
-2
lines changed

src/PreallocationTools.jl

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module PreallocationTools
33
using ArrayInterface, Adapt
44
using PrecompileTools
55

6-
struct FixedSizeDiffCache{T <: AbstractArray, S <: AbstractArray}
6+
struct FixedSizeDiffCache{T <: AbstractArray, S <: Union{AbstractArray, Nothing}}
77
du::T
88
dual_du::S
99
any_du::Vector{Any}
@@ -80,7 +80,7 @@ end
8080

8181
# DiffCache
8282

83-
struct DiffCache{T <: AbstractArray, S <: AbstractArray}
83+
struct DiffCache{T <: AbstractArray, S <: Union{AbstractArray, Nothing}}
8484
du::T
8585
dual_du::S
8686
any_du::Vector{Any}
@@ -302,6 +302,54 @@ function Base.resize!(dc::FixedSizeDiffCache, n::Integer)
302302
return dc
303303
end
304304

305+
# zero dispatches for PreallocationTools types
306+
function Base.zero(dc::DiffCache)
307+
dual_du_zero = dc.dual_du === nothing ? nothing : zero(dc.dual_du)
308+
DiffCache(zero(dc.du), dual_du_zero, Any[])
309+
end
310+
311+
function Base.zero(dc::FixedSizeDiffCache)
312+
dual_du_zero = dc.dual_du === nothing ? nothing : zero(dc.dual_du)
313+
FixedSizeDiffCache(zero(dc.du), dual_du_zero, Any[])
314+
end
315+
316+
function Base.zero(lbc::LazyBufferCache)
317+
LazyBufferCache(lbc.sizemap; initializer! = lbc.initializer!)
318+
end
319+
320+
function Base.zero(glbc::GeneralLazyBufferCache)
321+
GeneralLazyBufferCache(glbc.f)
322+
end
323+
324+
# copy dispatches for PreallocationTools types
325+
function Base.copy(dc::DiffCache)
326+
dual_du_copy = dc.dual_du === nothing ? nothing : copy(dc.dual_du)
327+
DiffCache(copy(dc.du), dual_du_copy, copy(dc.any_du))
328+
end
329+
330+
function Base.copy(dc::FixedSizeDiffCache)
331+
dual_du_copy = dc.dual_du === nothing ? nothing : copy(dc.dual_du)
332+
FixedSizeDiffCache(copy(dc.du), dual_du_copy, copy(dc.any_du))
333+
end
334+
335+
function Base.copy(lbc::LazyBufferCache)
336+
new_lbc = LazyBufferCache(lbc.sizemap; initializer! = lbc.initializer!)
337+
# Copy the internal buffer dictionary
338+
for (key, val) in lbc.bufs
339+
new_lbc.bufs[key] = copy(val)
340+
end
341+
new_lbc
342+
end
343+
344+
function Base.copy(glbc::GeneralLazyBufferCache)
345+
new_glbc = GeneralLazyBufferCache(glbc.f)
346+
# Copy the internal buffer dictionary
347+
for (key, val) in glbc.bufs
348+
new_glbc.bufs[key] = copy(val)
349+
end
350+
new_glbc
351+
end
352+
305353
export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache
306354
export get_tmp
307355

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ if GROUP == "All" || GROUP == "Core"
1919
@safetestset "DiffCache with SparseConnectivityTracer" include("sparse_connectivity_tracer.jl")
2020
@safetestset "LazyBufferCache" include("lbc.jl")
2121
@safetestset "GeneralLazyBufferCache" include("general_lbc.jl")
22+
@safetestset "Zero and Copy Dispatches" include("test_zero_copy.jl")
2223
end
2324

2425
if GROUP == "GPU"

test/test_zero_copy.jl

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
using Test, PreallocationTools, ForwardDiff
2+
3+
@testset "zero and copy dispatches" begin
4+
@testset "DiffCache" begin
5+
u = rand(10)
6+
cache = DiffCache(u, 5)
7+
8+
# Test zero
9+
zero_cache = zero(cache)
10+
@test isa(zero_cache, DiffCache)
11+
@test all(zero_cache.du .== 0)
12+
@test all(zero_cache.dual_du .== 0)
13+
@test isempty(zero_cache.any_du)
14+
15+
# Test copy
16+
copy_cache = copy(cache)
17+
@test isa(copy_cache, DiffCache)
18+
@test copy_cache.du == cache.du
19+
@test copy_cache.dual_du == cache.dual_du
20+
@test copy_cache.any_du == cache.any_du
21+
# Ensure it's a copy, not a reference
22+
copy_cache.du[1] = -999
23+
@test cache.du[1] != -999
24+
end
25+
26+
@testset "FixedSizeDiffCache" begin
27+
u = rand(10)
28+
cache = FixedSizeDiffCache(u, Val{5})
29+
30+
# Test zero
31+
zero_cache = zero(cache)
32+
@test isa(zero_cache, FixedSizeDiffCache)
33+
@test all(zero_cache.du .== 0)
34+
@test isempty(zero_cache.any_du)
35+
# Handle case where dual_du might be nothing
36+
if cache.dual_du !== nothing
37+
@test all(zero_cache.dual_du .== 0)
38+
else
39+
@test zero_cache.dual_du === nothing
40+
end
41+
42+
# Test copy
43+
copy_cache = copy(cache)
44+
@test isa(copy_cache, FixedSizeDiffCache)
45+
@test copy_cache.du == cache.du
46+
@test copy_cache.any_du == cache.any_du
47+
# Ensure it's a copy, not a reference
48+
copy_cache.du[1] = -999
49+
@test cache.du[1] != -999
50+
end
51+
52+
@testset "LazyBufferCache" begin
53+
lbc = LazyBufferCache(identity; initializer! = buf -> fill!(buf, 0.0))
54+
u = rand(10)
55+
buf = lbc[u] # Create a buffer in the cache
56+
57+
# Test zero - creates a new empty cache with same configuration
58+
zero_lbc = zero(lbc)
59+
@test isa(zero_lbc, LazyBufferCache)
60+
@test isempty(zero_lbc.bufs)
61+
@test zero_lbc.sizemap === lbc.sizemap
62+
@test zero_lbc.initializer! === lbc.initializer!
63+
64+
# Test copy
65+
copy_lbc = copy(lbc)
66+
@test isa(copy_lbc, LazyBufferCache)
67+
@test copy_lbc.sizemap === lbc.sizemap
68+
@test copy_lbc.initializer! === lbc.initializer!
69+
# Check that buffers were copied
70+
@test !isempty(copy_lbc.bufs)
71+
# Modify the copy to ensure it's independent
72+
buf_copy = copy_lbc[u]
73+
buf_copy[1] = -999
74+
@test buf[1] != -999
75+
end
76+
77+
@testset "GeneralLazyBufferCache" begin
78+
glbc = GeneralLazyBufferCache(u -> similar(u))
79+
u = rand(10)
80+
buf = glbc[u] # Create a buffer in the cache
81+
82+
# Test zero - creates a new empty cache with same function
83+
zero_glbc = zero(glbc)
84+
@test isa(zero_glbc, GeneralLazyBufferCache)
85+
@test isempty(zero_glbc.bufs)
86+
@test zero_glbc.f === glbc.f
87+
88+
# Test copy
89+
copy_glbc = copy(glbc)
90+
@test isa(copy_glbc, GeneralLazyBufferCache)
91+
@test copy_glbc.f === glbc.f
92+
# Check that buffers were copied
93+
@test !isempty(copy_glbc.bufs)
94+
# Modify the copy to ensure it's independent
95+
buf_copy = copy_glbc[u]
96+
buf_copy[1] = -999
97+
@test buf[1] != -999
98+
end
99+
100+
@testset "DiffCache with matrix" begin
101+
u = rand(5, 5)
102+
cache = DiffCache(u, 3)
103+
104+
# Test zero
105+
zero_cache = zero(cache)
106+
@test isa(zero_cache, DiffCache)
107+
@test size(zero_cache.du) == size(u)
108+
@test all(zero_cache.du .== 0)
109+
110+
# Test copy
111+
copy_cache = copy(cache)
112+
@test isa(copy_cache, DiffCache)
113+
@test size(copy_cache.du) == size(u)
114+
@test copy_cache.du == cache.du
115+
# Ensure it's a copy, not a reference
116+
copy_cache.du[1, 1] = -999
117+
@test cache.du[1, 1] != -999
118+
end
119+
end

0 commit comments

Comments
 (0)