@@ -85,13 +85,19 @@ function ChainRulesCore.rrule(
8585 f′ = NoTangent ()
8686 indp′ = NoTangent ()
8787
88- tunable = selected_tangents (buf′. tunable, tunable_idxs)
89- discrete = selected_tangents (buf′. discrete, disc_idxs)
90- constant = selected_tangents (buf′. constant, const_idxs)
91- nonnumeric = selected_tangents (buf′. nonnumeric, nn_idxs)
88+ if buf′ isa AbstractArray
89+ tunable = selected_tangents (buf′, tunable_idxs)
90+ discrete = constant = nonnumeric = NoTangent ()
91+ vals′ = map (i -> buf′[i. idx], idxs)
92+ else
93+ tunable = selected_tangents (buf′. tunable, tunable_idxs)
94+ discrete = selected_tangents (buf′. discrete, disc_idxs)
95+ constant = selected_tangents (buf′. constant, const_idxs)
96+ nonnumeric = selected_tangents (buf′. nonnumeric, nn_idxs)
97+ vals′ = map (i -> MTK. _ducktyped_parameter_values (buf′, i), idxs)
98+ end
9299 oldbuf′ = Tangent {typeof(oldbuf)} (; tunable, discrete, constant, nonnumeric)
93100 idxs′ = NoTangent ()
94- vals′ = map (i -> MTK. _ducktyped_parameter_values (buf′, i), idxs)
95101 return f′, indp′, oldbuf′, idxs′, vals′
96102 end
97103 end
0 commit comments