@@ -2,14 +2,102 @@ module MTKChainRulesCoreExt
22
33import  ModelingToolkit as MTK
44import  ChainRulesCore
5- import  ChainRulesCore:  NoTangent
5+ import  ChainRulesCore:  Tangent, ZeroTangent,  NoTangent, zero_tangent, unthunk 
66
77function  ChainRulesCore. rrule (:: Type{MTK.MTKParameters} , tunables, args... )
88    function  mtp_pullback (dt)
9+         dt =  unthunk (dt)
910        (NoTangent (), dt. tunable[1 : length (tunables)],
1011            ntuple (_ ->  NoTangent (), length (args))... )
1112    end 
1213    MTK. MTKParameters (tunables, args... ), mtp_pullback
1314end 
1415
16+ notangent_or_else (:: NoTangent , _, x) =  x
17+ notangent_or_else (_, x, _) =  x
18+ notangent_fallback (x, y) =  notangent_or_else (x, x, y)
19+ reduce_to_notangent (x, y) =  notangent_or_else (x, y, x)
20+ 
21+ function  subset_idxs (idxs, portion, template)
22+     ntuple (Val (length (template))) do  subi
23+         [Base. tail (idx. idx) for  idx in  idxs if  idx. portion ==  portion &&  idx. idx[1 ] ==  subi]
24+     end 
25+ end 
26+ 
27+ selected_tangents (:: NoTangent , _) =  ()
28+ selected_tangents (:: ZeroTangent , _) =  ZeroTangent ()
29+ function  selected_tangents (
30+         tangents:: AbstractArray{T} , idxs:: Vector{Tuple{Int}} ) where  {T <:  Number }
31+     selected_tangents (tangents, map (only, idxs))
32+ end 
33+ function  selected_tangents (tangents:: AbstractArray{T} , idxs... ) where  {T <:  Number }
34+     newtangents =  copy (tangents)
35+     view (newtangents, idxs... ) .=  zero (T)
36+     newtangents
37+ end 
38+ function  selected_tangents (
39+         tangents:: AbstractVector{T} , idxs) where  {S <:  Number , T <:  AbstractArray{S} }
40+     newtangents =  copy (tangents)
41+     for  i in  idxs
42+         j, k...  =  i
43+         if  k ==  ()
44+             newtangents[j] =  zero (newtangents[j])
45+         else 
46+             newtangents[j] =  selected_tangents (newtangents[j], k... )
47+         end 
48+     end 
49+     newtangents
50+ end 
51+ function  selected_tangents (tangents:: AbstractVector{T} , idxs) where  {T <:  AbstractArray }
52+     newtangents =  similar (tangents, Union{T, NoTangent})
53+     copyto! (newtangents, tangents)
54+     for  i in  idxs
55+         j, k...  =  i
56+         if  k ==  ()
57+             newtangents[j] =  NoTangent ()
58+         else 
59+             newtangents[j] =  selected_tangents (newtangents[j], k... )
60+         end 
61+     end 
62+     newtangents
63+ end 
64+ function  selected_tangents (
65+         tangents:: Union{Tangent{<:Tuple}, Tangent{T, <:Tuple}} , idxs) where  {T}
66+     ntuple (Val (length (tangents))) do  i
67+         selected_tangents (tangents[i], idxs[i])
68+     end 
69+ end 
70+ 
71+ function  ChainRulesCore. rrule (
72+         :: typeof (MTK. remake_buffer), indp, oldbuf:: MTK.MTKParameters , idxs, vals)
73+     if  idxs isa  AbstractSet
74+         idxs =  collect (idxs)
75+     end 
76+     idxs =  map (idxs) do  i
77+         i isa  MTK. ParameterIndex ?  i :  MTK. parameter_index (indp, i)
78+     end 
79+     newbuf =  MTK. remake_buffer (indp, oldbuf, idxs, vals)
80+     tunable_idxs =  reduce (
81+         vcat, (idx. idx for  idx in  idxs if  idx. portion isa  MTK. SciMLStructures. Tunable))
82+     disc_idxs =  subset_idxs (idxs, MTK. SciMLStructures. Discrete (), oldbuf. discrete)
83+     const_idxs =  subset_idxs (idxs, MTK. SciMLStructures. Constants (), oldbuf. constant)
84+     nn_idxs =  subset_idxs (idxs, MTK. NONNUMERIC_PORTION, oldbuf. nonnumeric)
85+ 
86+     function  remake_buffer_pullback (buf′)
87+         buf′ =  unthunk (buf′)
88+         f′ =  NoTangent ()
89+         indp′ =  NoTangent ()
90+ 
91+         tunable =  selected_tangents (buf′. tunable, tunable_idxs)
92+         discrete =  selected_tangents (buf′. discrete, disc_idxs)
93+         constant =  selected_tangents (buf′. constant, const_idxs)
94+         nonnumeric =  selected_tangents (buf′. nonnumeric, nn_idxs)
95+         oldbuf′ =  Tangent {typeof(oldbuf)} (; tunable, discrete, constant, nonnumeric)
96+         idxs′ =  NoTangent ()
97+         vals′ =  map (i ->  MTK. _ducktyped_parameter_values (buf′, i), idxs)
98+         return  f′, indp′, oldbuf′, idxs′, vals′
99+     end 
100+     newbuf, remake_buffer_pullback
101+ end 
102+ 
15103end 
0 commit comments