Skip to content

Commit 4a0ea26

Browse files
fix function wrap
1 parent 04f0a7b commit 4a0ea26

File tree

3 files changed

+5
-64
lines changed

3 files changed

+5
-64
lines changed

src/adjoint_common.jl

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -489,30 +489,7 @@ function get_pf(autojacvec::ReverseDiffVJP; _f = nothing, isinplace = nothing,
489489
end
490490

491491
function get_pf(autojacvec::EnzymeVJP; _f, isinplace, isRODE)
492-
pf = let f = _f
493-
if isinplace && isRODE
494-
function (out, u, _p, t, W)
495-
f(out, u, _p, t, W)
496-
nothing
497-
end
498-
elseif isinplace
499-
function (out, u, _p, t)
500-
f(out, u, _p, t)
501-
nothing
502-
end
503-
elseif !isinplace && isRODE
504-
function (out, u, _p, t, W)
505-
out .= f(u, _p, t, W)
506-
nothing
507-
end
508-
else
509-
# !isinplace
510-
function (out, u, _p, t)
511-
out .= f(u, _p, t)
512-
nothing
513-
end
514-
end
515-
end
492+
_f
516493
end
517494

518495
function get_pf(::MooncakeVJP, prob, _f)

src/gauss_adjoint.jl

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ end
1616
struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
1717
Alg <: GaussAdjoint,
1818
uType, SType, CPS, pType,
19-
fType <: DiffEqBase.AbstractDiffEqFunction,
19+
fType,
2020
GI <: GaussIntegrand,
2121
ICB} <: SensitivityFunction
2222
diffcache::C
@@ -82,7 +82,7 @@ function ODEGaussAdjointSensitivityFunction(
8282
nothing
8383
end
8484
diffcache, y = adjointdiffcache(
85-
g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg;
85+
g, sensealg, discrete, sol, dgdu, dgdp, f, alg;
8686
quad = true)
8787
return ODEGaussAdjointSensitivityFunction(diffcache, sensealg, discrete,
8888
y, sol, checkpoint_sol, sol.prob, f, gaussint, integrating_cb)
@@ -415,20 +415,7 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing)
415415
pf = nothing
416416
pJ = nothing
417417
elseif sensealg.autojacvec isa EnzymeVJP
418-
pf = let f = unwrappedf
419-
if DiffEqBase.isinplace(prob)
420-
function (out, u, _p, t)
421-
f(out, u, _p, t)
422-
nothing
423-
end
424-
else
425-
!DiffEqBase.isinplace(prob)
426-
function (out, u, _p, t)
427-
out .= f(u, _p, t)
428-
nothing
429-
end
430-
end
431-
end
418+
pf = unwrappedf
432419
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
433420
pJ = nothing
434421
elseif sensealg.autojacvec isa MooncakeVJP

src/quadrature_adjoint.jl

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -209,30 +209,7 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
209209
pf = nothing
210210
pJ = nothing
211211
elseif sensealg.autojacvec isa EnzymeVJP
212-
pf = let f = unwrappedf
213-
if DiffEqBase.isinplace(prob) && prob isa RODEProblem
214-
function (out, u, _p, t, W)
215-
f(out, u, _p, t, W)
216-
nothing
217-
end
218-
elseif DiffEqBase.isinplace(prob)
219-
function (out, u, _p, t)
220-
f(out, u, _p, t)
221-
nothing
222-
end
223-
elseif !DiffEqBase.isinplace(prob) && prob isa RODEProblem
224-
function (out, u, _p, t, W)
225-
out .= f(u, _p, t, W)
226-
nothing
227-
end
228-
else
229-
!DiffEqBase.isinplace(prob)
230-
function (out, u, _p, t)
231-
out .= f(u, _p, t)
232-
nothing
233-
end
234-
end
235-
end
212+
pf = unwrappedf
236213
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
237214
pJ = nothing
238215
elseif sensealg.autojacvec isa MooncakeVJP

0 commit comments

Comments
 (0)