Skip to content

Commit 43c3ccb

Browse files
fix: support initials in adjoints
1 parent d7da992 commit 43c3ccb

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

ext/MTKChainRulesCoreExt.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ function ChainRulesCore.rrule(
8080
newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals)
8181
tunable_idxs = reduce(
8282
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable))
83+
initials_idxs = reduce(
84+
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Initials))
8385
disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete)
8486
const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
8587
nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)
@@ -91,10 +93,12 @@ function ChainRulesCore.rrule(
9193
indp′ = NoTangent()
9294

9395
tunable = selected_tangents(buf′.tunable, tunable_idxs)
96+
initials = selected_tangents(buf′.initials, initials_idxs)
9497
discrete = selected_tangents(buf′.discrete, disc_idxs)
9598
constant = selected_tangents(buf′.constant, const_idxs)
9699
nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs)
97-
oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric)
100+
oldbuf′ = Tangent{typeof(oldbuf)}(;
101+
tunable, initials, discrete, constant, nonnumeric)
98102
idxs′ = NoTangent()
99103
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
100104
return f′, indp′, oldbuf′, idxs′, vals′

0 commit comments

Comments
 (0)