Skip to content

Commit 06cabdf

Browse files
committed
add repack/canonicalize in vec_pjac! to support SciMLStructs
1 parent 80934ae commit 06cabdf

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

src/gauss_adjoint.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -495,14 +495,15 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
495495
vtmp4 .= λ
496496
Enzyme.remake_zero!(tmp3)
497497
Enzyme.remake_zero!(out)
498-
498+
499+
dp = isscimlstructure(p) ? repack(out) : out
499500
if SciMLBase.isinplace(sol.prob.f)
500501
Enzyme.remake_zero!(tmp6)
501502

502503
Enzyme.autodiff(
503504
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
504505
Enzyme.Duplicated(tmp3, tmp4),
505-
Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t))
506+
Enzyme.Const(y), Enzyme.Duplicated(p, dp), Enzyme.Const(t))
506507
else
507508
function g(du, u, p, t)
508509
du .= f(u, p, t)
@@ -512,7 +513,10 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
512513
Enzyme.autodiff(
513514
Enzyme.Reverse, Enzyme.Duplicated(g, tmp6), Enzyme.Const,
514515
Enzyme.Duplicated(tmp3, tmp4),
515-
Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t))
516+
Enzyme.Const(y), Enzyme.Duplicated(p, dp), Enzyme.Const(t))
517+
end
518+
if isscimlstructure(p)
519+
out .+= canonicalize(Tunable(), dp)[1]
516520
end
517521
elseif sensealg.autojacvec isa MooncakeVJP
518522
_, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ)

0 commit comments

Comments
 (0)