@@ -13,11 +13,6 @@ function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
1313    MTK. MTKParameters (tunables, args... ), mtp_pullback
1414end 
1515
16- notangent_or_else (:: NoTangent , _, x) =  x
17- notangent_or_else (_, x, _) =  x
18- notangent_fallback (x, y) =  notangent_or_else (x, x, y)
19- reduce_to_notangent (x, y) =  notangent_or_else (x, y, x)
20- 
2116function  subset_idxs (idxs, portion, template)
2217    ntuple (Val (length (template))) do  subi
2318        [Base. tail (idx. idx) for  idx in  idxs if  idx. portion ==  portion &&  idx. idx[1 ] ==  subi]
@@ -83,21 +78,23 @@ function ChainRulesCore.rrule(
8378    const_idxs =  subset_idxs (idxs, MTK. SciMLStructures. Constants (), oldbuf. constant)
8479    nn_idxs =  subset_idxs (idxs, MTK. NONNUMERIC_PORTION, oldbuf. nonnumeric)
8580
86-     function  remake_buffer_pullback (buf′)
87-         buf′ =  unthunk (buf′)
88-         f′ =  NoTangent ()
89-         indp′ =  NoTangent ()
81+     pullback =  let  idxs =  idxs
82+         function  remake_buffer_pullback (buf′)
83+             buf′ =  unthunk (buf′)
84+             f′ =  NoTangent ()
85+             indp′ =  NoTangent ()
9086
91-         tunable =  selected_tangents (buf′. tunable, tunable_idxs)
92-         discrete =  selected_tangents (buf′. discrete, disc_idxs)
93-         constant =  selected_tangents (buf′. constant, const_idxs)
94-         nonnumeric =  selected_tangents (buf′. nonnumeric, nn_idxs)
95-         oldbuf′ =  Tangent {typeof(oldbuf)} (; tunable, discrete, constant, nonnumeric)
96-         idxs′ =  NoTangent ()
97-         vals′ =  map (i ->  MTK. _ducktyped_parameter_values (buf′, i), idxs)
98-         return  f′, indp′, oldbuf′, idxs′, vals′
87+             tunable =  selected_tangents (buf′. tunable, tunable_idxs)
88+             discrete =  selected_tangents (buf′. discrete, disc_idxs)
89+             constant =  selected_tangents (buf′. constant, const_idxs)
90+             nonnumeric =  selected_tangents (buf′. nonnumeric, nn_idxs)
91+             oldbuf′ =  Tangent {typeof(oldbuf)} (; tunable, discrete, constant, nonnumeric)
92+             idxs′ =  NoTangent ()
93+             vals′ =  map (i ->  MTK. _ducktyped_parameter_values (buf′, i), idxs)
94+             return  f′, indp′, oldbuf′, idxs′, vals′
95+         end 
9996    end 
100-     newbuf, remake_buffer_pullback 
97+     newbuf, pullback 
10198end 
10299
103100end 
0 commit comments