@@ -80,6 +80,8 @@ function ChainRulesCore.rrule(
80
80
newbuf = MTK. remake_buffer (indp, oldbuf, idxs, vals)
81
81
tunable_idxs = reduce (
82
82
vcat, (idx. idx for idx in idxs if idx. portion isa MTK. SciMLStructures. Tunable))
83
+ initials_idxs = reduce (
84
+ vcat, (idx. idx for idx in idxs if idx. portion isa MTK. SciMLStructures. Initials))
83
85
disc_idxs = subset_idxs (idxs, MTK. SciMLStructures. Discrete (), oldbuf. discrete)
84
86
const_idxs = subset_idxs (idxs, MTK. SciMLStructures. Constants (), oldbuf. constant)
85
87
nn_idxs = subset_idxs (idxs, MTK. NONNUMERIC_PORTION, oldbuf. nonnumeric)
@@ -91,10 +93,12 @@ function ChainRulesCore.rrule(
91
93
indp′ = NoTangent ()
92
94
93
95
tunable = selected_tangents (buf′. tunable, tunable_idxs)
96
+ initials = selected_tangents (buf′. initials, initials_idxs)
94
97
discrete = selected_tangents (buf′. discrete, disc_idxs)
95
98
constant = selected_tangents (buf′. constant, const_idxs)
96
99
nonnumeric = selected_tangents (buf′. nonnumeric, nn_idxs)
97
- oldbuf′ = Tangent {typeof(oldbuf)} (; tunable, discrete, constant, nonnumeric)
100
+ oldbuf′ = Tangent {typeof(oldbuf)} (;
101
+ tunable, initials, discrete, constant, nonnumeric)
98
102
idxs′ = NoTangent ()
99
103
vals′ = map (i -> MTK. _ducktyped_parameter_values (buf′, i), idxs)
100
104
return f′, indp′, oldbuf′, idxs′, vals′
0 commit comments