Skip to content

Commit 501efcb

Browse files
Merge pull request #1222 from SciML/void
Use SciMLBase.Void on EnzymeVJP functions
2 parents b87d92b + 2ea940a commit 501efcb

File tree

4 files changed

+9
-9
lines changed

4 files changed

+9
-9
lines changed

src/adjoint_common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ function get_pf(autojacvec::ReverseDiffVJP; _f = nothing, isinplace = nothing,
489489
end
490490

491491
function get_pf(autojacvec::EnzymeVJP; _f, isinplace, isRODE)
492-
_f
492+
isinplace ? SciMLBase.Void(_f) : _f
493493
end
494494

495495
function get_pf(::MooncakeVJP, prob, _f)

src/derivative_wrappers.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -724,19 +724,19 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
724724
# Correctness over speed
725725
# TODO: Get a fix for `remake_zero!` to allow reusing zero'd memory
726726
# https://github.com/EnzymeAD/Enzyme.jl/issues/2400
727-
_tmp6 = Enzyme.make_zero(f)
727+
_tmp6 = Enzyme.make_zero(SciMLBase.Void(f))
728728
else
729729
Enzyme.remake_zero!(_tmp6)
730730
end
731731

732732
if W === nothing
733-
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(f, _tmp6),
733+
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6),
734734
Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
735735
Enzyme.Duplicated(ytmp, tmp1),
736736
dup,
737737
Enzyme.Const(t))
738738
else
739-
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(f, _tmp6),
739+
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6),
740740
Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
741741
Enzyme.Duplicated(ytmp, tmp1),
742742
dup,

src/gauss_adjoint.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing)
415415
pf = nothing
416416
pJ = nothing
417417
elseif sensealg.autojacvec isa EnzymeVJP
418-
pf = unwrappedf
418+
pf = SciMLBase.isinplace(sol.prob.f) ? SciMLBase.Void(unwrappedf) : unwrappedf
419419
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
420420
pJ = nothing
421421
elseif sensealg.autojacvec isa MooncakeVJP
@@ -616,4 +616,4 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
616616
end
617617

618618
__maybe_adjoint(x::AbstractArray) = x'
619-
__maybe_adjoint(x) = x
619+
__maybe_adjoint(x) = x

src/quadrature_adjoint.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
209209
pf = nothing
210210
pJ = nothing
211211
elseif sensealg.autojacvec isa EnzymeVJP
212-
pf = unwrappedf
212+
pf = SciMLBase.isinplace(sol.prob.f) ? SciMLBase.Void(unwrappedf) : unwrappedf
213213
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
214214
pJ = nothing
215215
elseif sensealg.autojacvec isa MooncakeVJP
@@ -293,7 +293,7 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
293293
if SciMLBase.isinplace(sol.prob.f)
294294
Enzyme.remake_zero!(tmp6)
295295
Enzyme.autodiff(
296-
Enzyme.Reverse, Enzyme.Duplicated(f, tmp6), Enzyme.Const,
296+
Enzyme.Reverse, Enzyme.Duplicated(SciMLBase.Void(f), tmp6), Enzyme.Const,
297297
Enzyme.Duplicated(tmp3, tmp4),
298298
Enzyme.Const(y), dup, Enzyme.Const(t))
299299
else
@@ -537,4 +537,4 @@ function _update_integrand_and_dgrad(res, sensealg::QuadratureAdjoint, cb, integ
537537
vecjacobian!(dλ, integrand.y, dλ, integrand.p, t, fakeS; dgrad = dgrad)
538538
res .-= dgrad
539539
return integrand
540-
end
540+
end

0 commit comments

Comments
 (0)