Skip to content

Commit 9dfbfdc

Browse files
Merge pull request #1191 from jClugstor/enzyme_adjoint_lbc
use LazyBufferCache instead of FixedSizeDiffCache
2 parents cf64430 + a55f1de commit 9dfbfdc

File tree

3 files changed

+10
-11
lines changed

3 files changed

+10
-11
lines changed

src/SciMLSensitivity.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using FunctionWrappersWrappers: FunctionWrappersWrappers
1616
using GPUArraysCore: GPUArraysCore
1717
using LinearSolve: LinearSolve
1818
using PreallocationTools: PreallocationTools, dualcache, get_tmp, DiffCache,
19-
FixedSizeDiffCache
19+
FixedSizeDiffCache, LazyBufferCache
2020
using RandomNumbers: Xorshifts
2121
using RecursiveArrayTools: RecursiveArrayTools, AbstractDiffEqArray,
2222
AbstractVectorOfArray, ArrayPartition, DiffEqArray,

src/adjoint_common.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -444,10 +444,10 @@ function get_paramjac_config(autojacvec::EnzymeVJP, p::SciMLBase.NullParameters,
444444
autojacvec.chunksize
445445
end
446446

447-
paramjac_config = FixedSizeDiffCache(zero(y), chunk), p,
448-
FixedSizeDiffCache(zero(y), chunk),
449-
FixedSizeDiffCache(zero(y), chunk),
450-
FixedSizeDiffCache(zero(y), chunk)
447+
paramjac_config = LazyBufferCache(), p,
448+
LazyBufferCache(),
449+
LazyBufferCache(),
450+
LazyBufferCache()
451451
else
452452
paramjac_config = zero(y), p, zero(y), zero(y), zero(y)
453453
end
@@ -464,11 +464,10 @@ function get_paramjac_config(autojacvec::EnzymeVJP, p, f, y, _p, _t; numindvar,
464464
autojacvec.chunksize
465465
end
466466

467-
paramjac_config = FixedSizeDiffCache(zero(y), chunk),
468-
zero(_p),
469-
FixedSizeDiffCache(zero(y), chunk),
470-
FixedSizeDiffCache(zero(y), chunk),
471-
FixedSizeDiffCache(zero(y), chunk)
467+
paramjac_config = LazyBufferCache(), zero(_p),
468+
LazyBufferCache(),
469+
LazyBufferCache(),
470+
LazyBufferCache()
472471
else
473472
paramjac_config = zero(y), zero(_p), zero(y), zero(y), zero(y)
474473
end

src/derivative_wrappers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
677677

678678
_tmp1, tmp2, _tmp3, _tmp4, _tmp5, _tmp6 = S.diffcache.paramjac_config
679679

680-
if _tmp1 isa FixedSizeDiffCache
680+
if _tmp1 isa LazyBufferCache
681681
tmp1 = get_tmp(_tmp1, dλ)
682682
tmp3 = get_tmp(_tmp3, dλ)
683683
tmp4 = get_tmp(_tmp4, dλ)

0 commit comments

Comments
 (0)