|
16 | 16 | struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache, |
17 | 17 | Alg <: GaussAdjoint, |
18 | 18 | uType, SType, CPS, pType, |
19 | | - fType <: DiffEqBase.AbstractDiffEqFunction, |
| 19 | + fType, |
20 | 20 | GI <: GaussIntegrand, |
21 | 21 | ICB} <: SensitivityFunction |
22 | 22 | diffcache::C |
@@ -82,7 +82,7 @@ function ODEGaussAdjointSensitivityFunction( |
82 | 82 | nothing |
83 | 83 | end |
84 | 84 | diffcache, y = adjointdiffcache( |
85 | | - g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg; |
| 85 | + g, sensealg, discrete, sol, dgdu, dgdp, f, alg; |
86 | 86 | quad = true) |
87 | 87 | return ODEGaussAdjointSensitivityFunction(diffcache, sensealg, discrete, |
88 | 88 | y, sol, checkpoint_sol, sol.prob, f, gaussint, integrating_cb) |
@@ -415,20 +415,7 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing) |
415 | 415 | pf = nothing |
416 | 416 | pJ = nothing |
417 | 417 | elseif sensealg.autojacvec isa EnzymeVJP |
418 | | - pf = let f = unwrappedf |
419 | | - if DiffEqBase.isinplace(prob) |
420 | | - function (out, u, _p, t) |
421 | | - f(out, u, _p, t) |
422 | | - nothing |
423 | | - end |
424 | | - else |
425 | | - !DiffEqBase.isinplace(prob) |
426 | | - function (out, u, _p, t) |
427 | | - out .= f(u, _p, t) |
428 | | - nothing |
429 | | - end |
430 | | - end |
431 | | - end |
| 418 | + pf = unwrappedf |
432 | 419 | paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf) |
433 | 420 | pJ = nothing |
434 | 421 | elseif sensealg.autojacvec isa MooncakeVJP |
|
0 commit comments