Skip to content

Commit b9c74f8

Browse files
Merge pull request #1221 from SciML/remake_zero2
actually add remake_zero
2 parents 040b04c + 2115d7a commit b9c74f8

File tree

4 files changed

+16
-78
lines changed

4 files changed

+16
-78
lines changed

src/adjoint_common.jl

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -489,30 +489,7 @@ function get_pf(autojacvec::ReverseDiffVJP; _f = nothing, isinplace = nothing,
489489
end
490490

491491
function get_pf(autojacvec::EnzymeVJP; _f, isinplace, isRODE)
492-
pf = let f = _f
493-
if isinplace && isRODE
494-
function (out, u, _p, t, W)
495-
f(out, u, _p, t, W)
496-
nothing
497-
end
498-
elseif isinplace
499-
function (out, u, _p, t)
500-
f(out, u, _p, t)
501-
nothing
502-
end
503-
elseif !isinplace && isRODE
504-
function (out, u, _p, t, W)
505-
out .= f(u, _p, t, W)
506-
nothing
507-
end
508-
else
509-
# !isinplace
510-
function (out, u, _p, t)
511-
out .= f(u, _p, t)
512-
nothing
513-
end
514-
end
515-
end
492+
_f
516493
end
517494

518495
function get_pf(::MooncakeVJP, prob, _f)

src/derivative_wrappers.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -720,12 +720,15 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
720720
isautojacvec = get_jacvec(sensealg)
721721

722722
if inplace_sensitivity(S)
723-
724-
# Correctness over speed
725-
# TODO: Get a fix for `remake_zero!` to allow reusing zero'd memory
726-
# https://github.com/EnzymeAD/Enzyme.jl/issues/2400
727-
_tmp6 = Enzyme.make_zero(f)
728-
723+
if S isa CallbackSensitivityFunction
724+
# Correctness over speed
725+
# TODO: Get a fix for `remake_zero!` to allow reusing zero'd memory
726+
# https://github.com/EnzymeAD/Enzyme.jl/issues/2400
727+
_tmp6 = Enzyme.make_zero(f)
728+
else
729+
Enzyme.remake_zero!(_tmp6)
730+
end
731+
729732
if W === nothing
730733
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(f, _tmp6),
731734
Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),

src/gauss_adjoint.jl

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ end
1616
struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
1717
Alg <: GaussAdjoint,
1818
uType, SType, CPS, pType,
19-
fType <: DiffEqBase.AbstractDiffEqFunction,
19+
fType,
2020
GI <: GaussIntegrand,
2121
ICB} <: SensitivityFunction
2222
diffcache::C
@@ -82,7 +82,7 @@ function ODEGaussAdjointSensitivityFunction(
8282
nothing
8383
end
8484
diffcache, y = adjointdiffcache(
85-
g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg;
85+
g, sensealg, discrete, sol, dgdu, dgdp, f, alg;
8686
quad = true)
8787
return ODEGaussAdjointSensitivityFunction(diffcache, sensealg, discrete,
8888
y, sol, checkpoint_sol, sol.prob, f, gaussint, integrating_cb)
@@ -415,20 +415,7 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing)
415415
pf = nothing
416416
pJ = nothing
417417
elseif sensealg.autojacvec isa EnzymeVJP
418-
pf = let f = unwrappedf
419-
if DiffEqBase.isinplace(prob)
420-
function (out, u, _p, t)
421-
f(out, u, _p, t)
422-
nothing
423-
end
424-
else
425-
!DiffEqBase.isinplace(prob)
426-
function (out, u, _p, t)
427-
out .= f(u, _p, t)
428-
nothing
429-
end
430-
end
431-
end
418+
pf = unwrappedf
432419
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
433420
pJ = nothing
434421
elseif sensealg.autojacvec isa MooncakeVJP
@@ -508,10 +495,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
508495
Enzyme.remake_zero!(out)
509496

510497
if SciMLBase.isinplace(sol.prob.f)
511-
# Correctness over speed
512-
# TODO: Get a fix for `remake_zero!` to allow reusing zero'd memory
513-
# https://github.com/EnzymeAD/Enzyme.jl/issues/2400
514-
tmp6 = Enzyme.make_zero(tmp6)
498+
Enzyme.remake_zero!(tmp6)
515499

516500
Enzyme.autodiff(
517501
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,

src/quadrature_adjoint.jl

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -209,30 +209,7 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
209209
pf = nothing
210210
pJ = nothing
211211
elseif sensealg.autojacvec isa EnzymeVJP
212-
pf = let f = unwrappedf
213-
if DiffEqBase.isinplace(prob) && prob isa RODEProblem
214-
function (out, u, _p, t, W)
215-
f(out, u, _p, t, W)
216-
nothing
217-
end
218-
elseif DiffEqBase.isinplace(prob)
219-
function (out, u, _p, t)
220-
f(out, u, _p, t)
221-
nothing
222-
end
223-
elseif !DiffEqBase.isinplace(prob) && prob isa RODEProblem
224-
function (out, u, _p, t, W)
225-
out .= f(u, _p, t, W)
226-
nothing
227-
end
228-
else
229-
!DiffEqBase.isinplace(prob)
230-
function (out, u, _p, t)
231-
out .= f(u, _p, t)
232-
nothing
233-
end
234-
end
235-
end
212+
pf = unwrappedf
236213
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
237214
pJ = nothing
238215
elseif sensealg.autojacvec isa MooncakeVJP
@@ -314,10 +291,7 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
314291
end
315292

316293
if SciMLBase.isinplace(sol.prob.f)
317-
# Correctness over speed
318-
# TODO: Get a fix for `remake_zero!` to allow reusing zero'd memory
319-
# https://github.com/EnzymeAD/Enzyme.jl/issues/2400
320-
tmp6 = Enzyme.make_zero(f)
294+
Enzyme.remake_zero!(tmp6)
321295
Enzyme.autodiff(
322296
Enzyme.Reverse, Enzyme.Duplicated(f, tmp6), Enzyme.Const,
323297
Enzyme.Duplicated(tmp3, tmp4),

0 commit comments

Comments
 (0)