@@ -2,14 +2,102 @@ module MTKChainRulesCoreExt
22
33import ModelingToolkit as MTK
44import ChainRulesCore
5- import ChainRulesCore: NoTangent
5+ import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
66
77function ChainRulesCore. rrule (:: Type{MTK.MTKParameters} , tunables, args... )
88 function mtp_pullback (dt)
9+ dt = unthunk (dt)
910 (NoTangent (), dt. tunable[1 : length (tunables)],
1011 ntuple (_ -> NoTangent (), length (args))... )
1112 end
1213 MTK. MTKParameters (tunables, args... ), mtp_pullback
1314end
1415
16+ notangent_or_else (:: NoTangent , _, x) = x
17+ notangent_or_else (_, x, _) = x
18+ notangent_fallback (x, y) = notangent_or_else (x, x, y)
19+ reduce_to_notangent (x, y) = notangent_or_else (x, y, x)
20+
21+ function subset_idxs (idxs, portion, template)
22+ ntuple (Val (length (template))) do subi
23+ [Base. tail (idx. idx) for idx in idxs if idx. portion == portion && idx. idx[1 ] == subi]
24+ end
25+ end
26+
27+ selected_tangents (:: NoTangent , _) = ()
28+ selected_tangents (:: ZeroTangent , _) = ZeroTangent ()
29+ function selected_tangents (
30+ tangents:: AbstractArray{T} , idxs:: Vector{Tuple{Int}} ) where {T <: Number }
31+ selected_tangents (tangents, map (only, idxs))
32+ end
33+ function selected_tangents (tangents:: AbstractArray{T} , idxs... ) where {T <: Number }
34+ newtangents = copy (tangents)
35+ view (newtangents, idxs... ) .= zero (T)
36+ newtangents
37+ end
38+ function selected_tangents (
39+ tangents:: AbstractVector{T} , idxs) where {S <: Number , T <: AbstractArray{S} }
40+ newtangents = copy (tangents)
41+ for i in idxs
42+ j, k... = i
43+ if k == ()
44+ newtangents[j] = zero (newtangents[j])
45+ else
46+ newtangents[j] = selected_tangents (newtangents[j], k... )
47+ end
48+ end
49+ newtangents
50+ end
51+ function selected_tangents (tangents:: AbstractVector{T} , idxs) where {T <: AbstractArray }
52+ newtangents = similar (tangents, Union{T, NoTangent})
53+ copyto! (newtangents, tangents)
54+ for i in idxs
55+ j, k... = i
56+ if k == ()
57+ newtangents[j] = NoTangent ()
58+ else
59+ newtangents[j] = selected_tangents (newtangents[j], k... )
60+ end
61+ end
62+ newtangents
63+ end
64+ function selected_tangents (
65+ tangents:: Union{Tangent{<:Tuple}, Tangent{T, <:Tuple}} , idxs) where {T}
66+ ntuple (Val (length (tangents))) do i
67+ selected_tangents (tangents[i], idxs[i])
68+ end
69+ end
70+
71+ function ChainRulesCore. rrule (
72+ :: typeof (MTK. remake_buffer), indp, oldbuf:: MTK.MTKParameters , idxs, vals)
73+ if idxs isa AbstractSet
74+ idxs = collect (idxs)
75+ end
76+ idxs = map (idxs) do i
77+ i isa MTK. ParameterIndex ? i : MTK. parameter_index (indp, i)
78+ end
79+ newbuf = MTK. remake_buffer (indp, oldbuf, idxs, vals)
80+ tunable_idxs = reduce (
81+ vcat, (idx. idx for idx in idxs if idx. portion isa MTK. SciMLStructures. Tunable))
82+ disc_idxs = subset_idxs (idxs, MTK. SciMLStructures. Discrete (), oldbuf. discrete)
83+ const_idxs = subset_idxs (idxs, MTK. SciMLStructures. Constants (), oldbuf. constant)
84+ nn_idxs = subset_idxs (idxs, MTK. NONNUMERIC_PORTION, oldbuf. nonnumeric)
85+
86+ function remake_buffer_pullback (buf′)
87+ buf′ = unthunk (buf′)
88+ f′ = NoTangent ()
89+ indp′ = NoTangent ()
90+
91+ tunable = selected_tangents (buf′. tunable, tunable_idxs)
92+ discrete = selected_tangents (buf′. discrete, disc_idxs)
93+ constant = selected_tangents (buf′. constant, const_idxs)
94+ nonnumeric = selected_tangents (buf′. nonnumeric, nn_idxs)
95+ oldbuf′ = Tangent {typeof(oldbuf)} (; tunable, discrete, constant, nonnumeric)
96+ idxs′ = NoTangent ()
97+ vals′ = map (i -> MTK. _ducktyped_parameter_values (buf′, i), idxs)
98+ return f′, indp′, oldbuf′, idxs′, vals′
99+ end
100+ newbuf, remake_buffer_pullback
101+ end
102+
15103end
0 commit comments