@@ -382,15 +382,15 @@ function _setup_reverse_callbacks(
382382
383383 if sensealg isa GaussAdjoint
384384 vecjacobian! (
385- dgrad, integrator . p ,
385+ nothing , y ,
386386 integrator. f. f. integrating_cb. affect!. integrand_values. integrand,
387- y , integrator. t, fakeSp; dgrad = nothing , dy = nothing
387+ integrator . p , integrator. t, fakeSp; dgrad = dgrad , dy = nothing
388388 )
389389 integrator. f. f. integrating_cb. affect!. integrand_values. integrand .= dgrad
390390 else
391391 vecjacobian! (
392- dgrad, integrator . p , grad, y , integrator. t, fakeSp;
393- dgrad = nothing , dy = nothing
392+ nothing , y , grad, integrator . p , integrator. t, fakeSp;
393+ dgrad = dgrad , dy = nothing
394394 )
395395 grad .= dgrad
396396 end
@@ -483,7 +483,7 @@ mutable struct CallbackAffectPWrapper{cbType, AJV, EI, T} <: Function
483483 tprev:: T
484484end
485485
486- function (ff:: CallbackAffectPWrapper )(dp, p, u , t)
486+ function (ff:: CallbackAffectPWrapper )(dp, u, p , t)
487487 _affect! = get_affect! (ff. cb, ff. pos_neg)
488488 fakeinteg = get_FakeIntegrator (ff. autojacvec, u, p, t, ff. tprev)
489489 if ff. cb isa VectorContinuousCallback
@@ -510,6 +510,30 @@ function get_FakeIntegrator(autojacvec::ReverseDiffVJP, u, p, t, tprev)
510510end
511511get_FakeIntegrator (autojacvec:: EnzymeVJP , u, p, t, tprev) = FakeIntegrator (u, p, t, tprev)
512512
513+ function _get_wp_paramjac_config (autojacvec:: EnzymeVJP , _p, wp, y, __p, _t)
514+ return (zero (y), zero (_p), zero (_p), zero (_p), zero (y))
515+ end
516+
517+ function _get_wp_paramjac_config (autojacvec:: ReverseDiffVJP , _p, wp, y, __p, _t)
518+ if _p === nothing || _p isa SciMLBase. NullParameters
519+ tunables, repack = _p, identity
520+ else
521+ tunables, repack, aliases = canonicalize (Tunable (), _p)
522+ end
523+ tunables_inner = tunables
524+ tape = ReverseDiff. GradientTape ((y, tunables_inner, [_t])) do u, p, t
525+ dp1 = similar (p, length (p))
526+ dp1 .= false
527+ wp (dp1, u, repack (p), first (t))
528+ return vec (dp1)
529+ end
530+ if compile_tape (autojacvec)
531+ return ReverseDiff. compile (tape)
532+ else
533+ return tape
534+ end
535+ end
536+
513537function get_cb_diffcaches (
514538 cb:: Union {
515539 DiscreteCallback, ContinuousCallback,
@@ -562,11 +586,8 @@ function get_cb_diffcaches(
562586 nothing , nothing , nothing , false
563587 )
564588
565- paramjac_config = get_paramjac_config (
566- autojacvec, y, wp, _p, y, _t;
567- numindvar = length (y), alg = nothing ,
568- isinplace = true , isRODE = false ,
569- _W = nothing
589+ paramjac_config = _get_wp_paramjac_config (
590+ autojacvec, _p, wp, y, _p, _t
570591 )
571592 pf = get_pf (autojacvec; _f = wp, isinplace = true , isRODE = false )
572593 if autojacvec isa EnzymeVJP
0 commit comments