Skip to content

Commit 3eb3884

Browse files
Merge pull request #135 from ChrisRackauckas-Claude/resize-dispatch
Add resize! dispatch for DiffCache and FixedSizeDiffCache
2 parents f761381 + 47a6f26 commit 3eb3884

File tree

2 files changed

+81
-1
lines changed

2 files changed

+81
-1
lines changed

src/PreallocationTools.jl

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dualarraycreator(args...) = nothing
1414

1515
function FixedSizeDiffCache(u::AbstractArray{T}, siz,
1616
::Type{Val{chunk_size}}) where {T, chunk_size}
17-
x = dualarraycreator(u, siz, Val{chunk_size})
17+
x = dualarraycreator(u, siz, Val{chunk_size})
1818
xany = Any[]
1919
FixedSizeDiffCache(deepcopy(u), x, xany)
2020
end
@@ -233,6 +233,44 @@ function get_tmp(b::GeneralLazyBufferCache, u::T) where {T}
233233
end
234234
Base.getindex(b::GeneralLazyBufferCache, u::T) where {T} = get_tmp(b, u)
235235

236+
# resize! methods for PreallocationTools types
237+
# Note: resize! only works for 1D arrays (vectors)
238+
function Base.resize!(dc::DiffCache, n::Integer)
239+
# Only resize if the array is a vector
240+
if dc.du isa AbstractVector
241+
resize!(dc.du, n)
242+
else
243+
throw(ArgumentError("resize! is only supported for DiffCache with vector arrays, got $(typeof(dc.du))"))
244+
end
245+
# dual_du is often pre-allocated for ForwardDiff dual numbers,
246+
# and may need special handling based on chunk size
247+
# Only resize if it's a vector
248+
if dc.dual_du isa AbstractVector
249+
resize!(dc.dual_du, n)
250+
end
251+
# Always resize the any_du cache
252+
resize!(dc.any_du, n)
253+
return dc
254+
end
255+
256+
function Base.resize!(dc::FixedSizeDiffCache, n::Integer)
257+
# Only resize if the array is a vector
258+
if dc.du isa AbstractVector
259+
resize!(dc.du, n)
260+
else
261+
throw(ArgumentError("resize! is only supported for FixedSizeDiffCache with vector arrays, got $(typeof(dc.du))"))
262+
end
263+
# dual_du is often pre-allocated for ForwardDiff dual numbers,
264+
# and may need special handling based on chunk size
265+
# Only resize if it's a vector
266+
if dc.dual_du isa AbstractVector
267+
resize!(dc.dual_du, n)
268+
end
269+
# Always resize the any_du cache
270+
resize!(dc.any_du, n)
271+
return dc
272+
end
273+
236274
export GeneralLazyBufferCache, FixedSizeDiffCache, DiffCache, LazyBufferCache, dualcache
237275
export get_tmp
238276

test/core_resizing.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,45 @@ _du = DiffCache(du)
5252
f = A -> loss(_du, u, A, 0.0)
5353
analyticalsolution = [3.0 0; 0 0]
5454
@test ForwardDiff.gradient(f, A) analyticalsolution
55+
56+
# Test resize! functionality for DiffCache
57+
@testset "resize! for DiffCache" begin
58+
u = rand(10)
59+
dc = DiffCache(u)
60+
61+
# Initial size
62+
@test length(dc.du) == 10
63+
@test length(dc.any_du) == 0 # Initially empty
64+
65+
# Resize to larger
66+
resize!(dc, 20)
67+
@test length(dc.du) == 20
68+
69+
# Resize to smaller
70+
resize!(dc, 5)
71+
@test length(dc.du) == 5
72+
73+
# Test that it returns the cache itself
74+
@test resize!(dc, 8) === dc
75+
end
76+
77+
# Test resize! functionality for FixedSizeDiffCache
78+
@testset "resize! for FixedSizeDiffCache" begin
79+
u = rand(10)
80+
dc = FixedSizeDiffCache(u)
81+
82+
# Initial size
83+
@test length(dc.du) == 10
84+
@test length(dc.any_du) == 0 # Initially empty
85+
86+
# Resize to larger
87+
resize!(dc, 20)
88+
@test length(dc.du) == 20
89+
90+
# Resize to smaller
91+
resize!(dc, 5)
92+
@test length(dc.du) == 5
93+
94+
# Test that it returns the cache itself
95+
@test resize!(dc, 8) === dc
96+
end

0 commit comments

Comments
 (0)