@@ -13,11 +13,6 @@ function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
1313 MTK. MTKParameters (tunables, args... ), mtp_pullback
1414end
1515
16- notangent_or_else (:: NoTangent , _, x) = x
17- notangent_or_else (_, x, _) = x
18- notangent_fallback (x, y) = notangent_or_else (x, x, y)
19- reduce_to_notangent (x, y) = notangent_or_else (x, y, x)
20-
2116function subset_idxs (idxs, portion, template)
2217 ntuple (Val (length (template))) do subi
2318 [Base. tail (idx. idx) for idx in idxs if idx. portion == portion && idx. idx[1 ] == subi]
@@ -83,21 +78,23 @@ function ChainRulesCore.rrule(
8378 const_idxs = subset_idxs (idxs, MTK. SciMLStructures. Constants (), oldbuf. constant)
8479 nn_idxs = subset_idxs (idxs, MTK. NONNUMERIC_PORTION, oldbuf. nonnumeric)
8580
86- function remake_buffer_pullback (buf′)
87- buf′ = unthunk (buf′)
88- f′ = NoTangent ()
89- indp′ = NoTangent ()
81+ pullback = let idxs = idxs
82+ function remake_buffer_pullback (buf′)
83+ buf′ = unthunk (buf′)
84+ f′ = NoTangent ()
85+ indp′ = NoTangent ()
9086
91- tunable = selected_tangents (buf′. tunable, tunable_idxs)
92- discrete = selected_tangents (buf′. discrete, disc_idxs)
93- constant = selected_tangents (buf′. constant, const_idxs)
94- nonnumeric = selected_tangents (buf′. nonnumeric, nn_idxs)
95- oldbuf′ = Tangent {typeof(oldbuf)} (; tunable, discrete, constant, nonnumeric)
96- idxs′ = NoTangent ()
97- vals′ = map (i -> MTK. _ducktyped_parameter_values (buf′, i), idxs)
98- return f′, indp′, oldbuf′, idxs′, vals′
87+ tunable = selected_tangents (buf′. tunable, tunable_idxs)
88+ discrete = selected_tangents (buf′. discrete, disc_idxs)
89+ constant = selected_tangents (buf′. constant, const_idxs)
90+ nonnumeric = selected_tangents (buf′. nonnumeric, nn_idxs)
91+ oldbuf′ = Tangent {typeof(oldbuf)} (; tunable, discrete, constant, nonnumeric)
92+ idxs′ = NoTangent ()
93+ vals′ = map (i -> MTK. _ducktyped_parameter_values (buf′, i), idxs)
94+ return f′, indp′, oldbuf′, idxs′, vals′
95+ end
9996 end
100- newbuf, remake_buffer_pullback
97+ newbuf, pullback
10198end
10299
103100end
0 commit comments