Skip to content

Commit fd3df39

Browse files
Merge pull request #138 from ChrisRackauckas-Claude/add-zero-copy-dispatches
Add zero and copy dispatches for PreallocationTools types
2 parents 07e3a76 + 6630f60 commit fd3df39

File tree

3 files changed

+160
-0
lines changed

3 files changed

+160
-0
lines changed

src/PreallocationTools.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,50 @@ function Base.resize!(dc::FixedSizeDiffCache, n::Integer)
302302
return dc
303303
end
304304

305+
# zero dispatches for PreallocationTools types
306+
function Base.zero(dc::DiffCache)
307+
DiffCache(zero(dc.du), zero(dc.dual_du), Any[])
308+
end
309+
310+
function Base.zero(dc::FixedSizeDiffCache)
311+
FixedSizeDiffCache(zero(dc.du), zero(dc.dual_du), Any[])
312+
end
313+
314+
function Base.zero(lbc::LazyBufferCache)
315+
LazyBufferCache(lbc.sizemap; initializer! = lbc.initializer!)
316+
end
317+
318+
function Base.zero(glbc::GeneralLazyBufferCache)
319+
GeneralLazyBufferCache(glbc.f)
320+
end
321+
322+
# copy dispatches for PreallocationTools types
323+
function Base.copy(dc::DiffCache)
324+
DiffCache(copy(dc.du), copy(dc.dual_du), copy(dc.any_du))
325+
end
326+
327+
function Base.copy(dc::FixedSizeDiffCache)
328+
FixedSizeDiffCache(copy(dc.du), copy(dc.dual_du), copy(dc.any_du))
329+
end
330+
331+
function Base.copy(lbc::LazyBufferCache)
332+
new_lbc = LazyBufferCache(lbc.sizemap; initializer! = lbc.initializer!)
333+
# Copy the internal buffer dictionary
334+
for (key, val) in lbc.bufs
335+
new_lbc.bufs[key] = copy(val)
336+
end
337+
new_lbc
338+
end
339+
340+
function Base.copy(glbc::GeneralLazyBufferCache)
341+
new_glbc = GeneralLazyBufferCache(glbc.f)
342+
# Copy the internal buffer dictionary
343+
for (key, val) in glbc.bufs
344+
new_glbc.bufs[key] = copy(val)
345+
end
346+
new_glbc
347+
end
348+
305349
export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache
306350
export get_tmp
307351

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ if GROUP == "All" || GROUP == "Core"
1919
@safetestset "DiffCache with SparseConnectivityTracer" include("sparse_connectivity_tracer.jl")
2020
@safetestset "LazyBufferCache" include("lbc.jl")
2121
@safetestset "GeneralLazyBufferCache" include("general_lbc.jl")
22+
@safetestset "Zero and Copy Dispatches" include("test_zero_copy.jl")
2223
end
2324

2425
if GROUP == "GPU"

test/test_zero_copy.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
using Test, PreallocationTools, ForwardDiff
2+
3+
@testset "zero and copy dispatches" begin
4+
@testset "DiffCache" begin
5+
u = rand(10)
6+
cache = DiffCache(u, 5)
7+
8+
# Test zero
9+
zero_cache = zero(cache)
10+
@test isa(zero_cache, DiffCache)
11+
@test all(zero_cache.du .== 0)
12+
@test all(zero_cache.dual_du .== 0)
13+
@test isempty(zero_cache.any_du)
14+
15+
# Test copy
16+
copy_cache = copy(cache)
17+
@test isa(copy_cache, DiffCache)
18+
@test copy_cache.du == cache.du
19+
@test copy_cache.dual_du == cache.dual_du
20+
@test copy_cache.any_du == cache.any_du
21+
# Ensure it's a copy, not a reference
22+
copy_cache.du[1] = -999
23+
@test cache.du[1] != -999
24+
end
25+
26+
@testset "FixedSizeDiffCache" begin
27+
u = rand(10)
28+
cache = FixedSizeDiffCache(u, Val{5})
29+
30+
# Test zero
31+
zero_cache = zero(cache)
32+
@test isa(zero_cache, FixedSizeDiffCache)
33+
@test all(zero_cache.du .== 0)
34+
@test all(zero_cache.dual_du .== 0)
35+
@test isempty(zero_cache.any_du)
36+
37+
# Test copy
38+
copy_cache = copy(cache)
39+
@test isa(copy_cache, FixedSizeDiffCache)
40+
@test copy_cache.du == cache.du
41+
@test copy_cache.dual_du == cache.dual_du
42+
@test copy_cache.any_du == cache.any_du
43+
# Ensure it's a copy, not a reference
44+
copy_cache.du[1] = -999
45+
@test cache.du[1] != -999
46+
end
47+
48+
@testset "LazyBufferCache" begin
49+
lbc = LazyBufferCache(identity; initializer! = buf -> fill!(buf, 0.0))
50+
u = rand(10)
51+
buf = lbc[u] # Create a buffer in the cache
52+
53+
# Test zero - creates a new empty cache with same configuration
54+
zero_lbc = zero(lbc)
55+
@test isa(zero_lbc, LazyBufferCache)
56+
@test isempty(zero_lbc.bufs)
57+
@test zero_lbc.sizemap === lbc.sizemap
58+
@test zero_lbc.initializer! === lbc.initializer!
59+
60+
# Test copy
61+
copy_lbc = copy(lbc)
62+
@test isa(copy_lbc, LazyBufferCache)
63+
@test copy_lbc.sizemap === lbc.sizemap
64+
@test copy_lbc.initializer! === lbc.initializer!
65+
# Check that buffers were copied
66+
@test !isempty(copy_lbc.bufs)
67+
# Modify the copy to ensure it's independent
68+
buf_copy = copy_lbc[u]
69+
buf_copy[1] = -999
70+
@test buf[1] != -999
71+
end
72+
73+
@testset "GeneralLazyBufferCache" begin
74+
glbc = GeneralLazyBufferCache(u -> similar(u))
75+
u = rand(10)
76+
buf = glbc[u] # Create a buffer in the cache
77+
78+
# Test zero - creates a new empty cache with same function
79+
zero_glbc = zero(glbc)
80+
@test isa(zero_glbc, GeneralLazyBufferCache)
81+
@test isempty(zero_glbc.bufs)
82+
@test zero_glbc.f === glbc.f
83+
84+
# Test copy
85+
copy_glbc = copy(glbc)
86+
@test isa(copy_glbc, GeneralLazyBufferCache)
87+
@test copy_glbc.f === glbc.f
88+
# Check that buffers were copied
89+
@test !isempty(copy_glbc.bufs)
90+
# Modify the copy to ensure it's independent
91+
buf_copy = copy_glbc[u]
92+
buf_copy[1] = -999
93+
@test buf[1] != -999
94+
end
95+
96+
@testset "DiffCache with matrix" begin
97+
u = rand(5, 5)
98+
cache = DiffCache(u, 3)
99+
100+
# Test zero
101+
zero_cache = zero(cache)
102+
@test isa(zero_cache, DiffCache)
103+
@test size(zero_cache.du) == size(u)
104+
@test all(zero_cache.du .== 0)
105+
106+
# Test copy
107+
copy_cache = copy(cache)
108+
@test isa(copy_cache, DiffCache)
109+
@test size(copy_cache.du) == size(u)
110+
@test copy_cache.du == cache.du
111+
# Ensure it's a copy, not a reference
112+
copy_cache.du[1, 1] = -999
113+
@test cache.du[1, 1] != -999
114+
end
115+
end

0 commit comments

Comments
 (0)