Skip to content

Commit 3172d7f

Browse files
Merge pull request #1348 from ChrisRackauckas-Claude/normalize-callback-vjp-args
Normalize CallbackAffectPWrapper VJP argument ordering
2 parents 157a78f + 352e2a3 commit 3172d7f

File tree

3 files changed

+39
-19
lines changed

3 files changed

+39
-19
lines changed

src/callback_tracking.jl

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -382,15 +382,15 @@ function _setup_reverse_callbacks(
382382

383383
if sensealg isa GaussAdjoint
384384
vecjacobian!(
385-
dgrad, integrator.p,
385+
nothing, y,
386386
integrator.f.f.integrating_cb.affect!.integrand_values.integrand,
387-
y, integrator.t, fakeSp; dgrad = nothing, dy = nothing
387+
integrator.p, integrator.t, fakeSp; dgrad = dgrad, dy = nothing
388388
)
389389
integrator.f.f.integrating_cb.affect!.integrand_values.integrand .= dgrad
390390
else
391391
vecjacobian!(
392-
dgrad, integrator.p, grad, y, integrator.t, fakeSp;
393-
dgrad = nothing, dy = nothing
392+
nothing, y, grad, integrator.p, integrator.t, fakeSp;
393+
dgrad = dgrad, dy = nothing
394394
)
395395
grad .= dgrad
396396
end
@@ -483,7 +483,7 @@ mutable struct CallbackAffectPWrapper{cbType, AJV, EI, T} <: Function
483483
tprev::T
484484
end
485485

486-
function (ff::CallbackAffectPWrapper)(dp, p, u, t)
486+
function (ff::CallbackAffectPWrapper)(dp, u, p, t)
487487
_affect! = get_affect!(ff.cb, ff.pos_neg)
488488
fakeinteg = get_FakeIntegrator(ff.autojacvec, u, p, t, ff.tprev)
489489
if ff.cb isa VectorContinuousCallback
@@ -510,6 +510,30 @@ function get_FakeIntegrator(autojacvec::ReverseDiffVJP, u, p, t, tprev)
510510
end
511511
get_FakeIntegrator(autojacvec::EnzymeVJP, u, p, t, tprev) = FakeIntegrator(u, p, t, tprev)
512512

513+
function _get_wp_paramjac_config(autojacvec::EnzymeVJP, _p, wp, y, __p, _t)
514+
return (zero(y), zero(_p), zero(_p), zero(_p), zero(y))
515+
end
516+
517+
function _get_wp_paramjac_config(autojacvec::ReverseDiffVJP, _p, wp, y, __p, _t)
518+
if _p === nothing || _p isa SciMLBase.NullParameters
519+
tunables, repack = _p, identity
520+
else
521+
tunables, repack, aliases = canonicalize(Tunable(), _p)
522+
end
523+
tunables_inner = tunables
524+
tape = ReverseDiff.GradientTape((y, tunables_inner, [_t])) do u, p, t
525+
dp1 = similar(p, length(p))
526+
dp1 .= false
527+
wp(dp1, u, repack(p), first(t))
528+
return vec(dp1)
529+
end
530+
if compile_tape(autojacvec)
531+
return ReverseDiff.compile(tape)
532+
else
533+
return tape
534+
end
535+
end
536+
513537
function get_cb_diffcaches(
514538
cb::Union{
515539
DiscreteCallback, ContinuousCallback,
@@ -562,11 +586,8 @@ function get_cb_diffcaches(
562586
nothing, nothing, nothing, false
563587
)
564588

565-
paramjac_config = get_paramjac_config(
566-
autojacvec, y, wp, _p, y, _t;
567-
numindvar = length(y), alg = nothing,
568-
isinplace = true, isRODE = false,
569-
_W = nothing
589+
paramjac_config = _get_wp_paramjac_config(
590+
autojacvec, _p, wp, y, _p, _t
570591
)
571592
pf = get_pf(autojacvec; _f = wp, isinplace = true, isRODE = false)
572593
if autojacvec isa EnzymeVJP

src/derivative_wrappers.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,8 +492,10 @@ function _vecjacobian!(
492492
_y = eltype(y) === eltype(λ) ? y : convert.(promote_type(eltype(y), eltype(λ)), y)
493493
if W === nothing
494494
_tunables, _repack, _ = canonicalize(Tunable(), _p)
495+
_is_pswap = TS <: CallbackSensitivityFunctionPSwap
495496
tape = ReverseDiff.GradientTape((_y, _tunables, [t])) do u, p, t
496-
du1 = similar(u, size(u))
497+
du1 = _is_pswap ? similar(p, length(p)) : similar(u, size(u))
498+
du1 .= false
497499
f(du1, u, _repack(p), first(t))
498500
return vec(du1)
499501
end

src/quadrature_adjoint.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -548,14 +548,11 @@ function _update_integrand_and_dgrad(
548548

549549
_p = similar(integrand.p, size(integrand.p))
550550
_p .= false
551-
wp(_p, integrand.p, integrand.y, t)
551+
wp(_p, integrand.y, integrand.p, t)
552552

553553
if _p != integrand.p
554-
paramjac_config = get_paramjac_config(
555-
sensealg.autojacvec, integrand.y, wp, _p, integrand.y, t;
556-
numindvar = length(integrand.y), alg = nothing,
557-
isinplace = true, isRODE = false,
558-
_W = nothing
554+
paramjac_config = _get_wp_paramjac_config(
555+
sensealg.autojacvec, integrand.p, wp, integrand.y, integrand.p, t
559556
)
560557
pf = get_pf(sensealg.autojacvec; _f = wp, isinplace = true, isRODE = false)
561558
if sensealg.autojacvec isa EnzymeVJP
@@ -572,8 +569,8 @@ function _update_integrand_and_dgrad(
572569
fakeSp = CallbackSensitivityFunctionPSwap(wp, sensealg, diffcache_wp, sol.prob)
573570
#vjp with Jacobin given by dw/dp before event and vector given by grad
574571
vecjacobian!(
575-
res, integrand.p, res, integrand.y, t, fakeSp;
576-
dgrad = nothing, dy = nothing
572+
nothing, integrand.y, res, integrand.p, t, fakeSp;
573+
dgrad = res, dy = nothing
577574
)
578575
integrand = update_p_integrand(integrand, _p)
579576
end

0 commit comments

Comments
 (0)