@@ -80,6 +80,8 @@ function ChainRulesCore.rrule(
8080 newbuf = MTK. remake_buffer (indp, oldbuf, idxs, vals)
8181 tunable_idxs = reduce (
8282 vcat, (idx. idx for idx in idxs if idx. portion isa MTK. SciMLStructures. Tunable))
83+ initials_idxs = reduce (
84+ vcat, (idx. idx for idx in idxs if idx. portion isa MTK. SciMLStructures. Initials))
8385 disc_idxs = subset_idxs (idxs, MTK. SciMLStructures. Discrete (), oldbuf. discrete)
8486 const_idxs = subset_idxs (idxs, MTK. SciMLStructures. Constants (), oldbuf. constant)
8587 nn_idxs = subset_idxs (idxs, MTK. NONNUMERIC_PORTION, oldbuf. nonnumeric)
@@ -91,10 +93,12 @@ function ChainRulesCore.rrule(
9193 indp′ = NoTangent ()
9294
9395 tunable = selected_tangents (buf′. tunable, tunable_idxs)
96+ initials = selected_tangents (buf′. initials, initials_idxs)
9497 discrete = selected_tangents (buf′. discrete, disc_idxs)
9598 constant = selected_tangents (buf′. constant, const_idxs)
9699 nonnumeric = selected_tangents (buf′. nonnumeric, nn_idxs)
97- oldbuf′ = Tangent {typeof(oldbuf)} (; tunable, discrete, constant, nonnumeric)
100+ oldbuf′ = Tangent {typeof(oldbuf)} (;
101+ tunable, initials, discrete, constant, nonnumeric)
98102 idxs′ = NoTangent ()
99103 vals′ = map (i -> MTK. _ducktyped_parameter_values (buf′, i), idxs)
100104 return f′, indp′, oldbuf′, idxs′, vals′
0 commit comments