1
- module MTKChainRulesCoreExt
2
-
3
- import ModelingToolkit as MTK
4
- import ChainRulesCore
5
- import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
6
-
7
- function ChainRulesCore. rrule (:: Type{MTK.MTKParameters} , tunables, args... )
1
+ function ChainRulesCore. rrule (:: Type{MTKParameters} , tunables, args... )
8
2
function mtp_pullback (dt)
9
3
dt = unthunk (dt)
10
4
dtunables = dt isa AbstractArray ? dt : dt. tunable
11
5
(NoTangent (), dtunables[1 : length (tunables)],
12
6
ntuple (_ -> NoTangent (), length (args))... )
13
7
end
14
- MTK . MTKParameters (tunables, args... ), mtp_pullback
8
+ MTKParameters (tunables, args... ), mtp_pullback
15
9
end
16
10
17
11
function subset_idxs (idxs, portion, template)
@@ -70,23 +64,23 @@ function selected_tangents(
70
64
end
71
65
72
66
function ChainRulesCore. rrule (
73
- :: typeof (MTK . remake_buffer), indp, oldbuf:: MTK. MTKParameters , idxs, vals)
67
+ :: typeof (remake_buffer), indp, oldbuf:: MTKParameters , idxs, vals)
74
68
if idxs isa AbstractSet
75
69
idxs = collect (idxs)
76
70
end
77
71
idxs = map (idxs) do i
78
- i isa MTK . ParameterIndex ? i : MTK . parameter_index (indp, i)
72
+ i isa ParameterIndex ? i : parameter_index (indp, i)
79
73
end
80
- newbuf = MTK . remake_buffer (indp, oldbuf, idxs, vals)
74
+ newbuf = remake_buffer (indp, oldbuf, idxs, vals)
81
75
tunable_idxs = reduce (
82
- vcat, (idx. idx for idx in idxs if idx. portion isa MTK . SciMLStructures. Tunable);
76
+ vcat, (idx. idx for idx in idxs if idx. portion isa SciMLStructures. Tunable);
83
77
init = Union{Int, AbstractVector{Int}}[])
84
78
initials_idxs = reduce (
85
- vcat, (idx. idx for idx in idxs if idx. portion isa MTK . SciMLStructures. Initials);
79
+ vcat, (idx. idx for idx in idxs if idx. portion isa SciMLStructures. Initials);
86
80
init = Union{Int, AbstractVector{Int}}[])
87
- disc_idxs = subset_idxs (idxs, MTK . SciMLStructures. Discrete (), oldbuf. discrete)
88
- const_idxs = subset_idxs (idxs, MTK . SciMLStructures. Constants (), oldbuf. constant)
89
- nn_idxs = subset_idxs (idxs, MTK . NONNUMERIC_PORTION, oldbuf. nonnumeric)
81
+ disc_idxs = subset_idxs (idxs, SciMLStructures. Discrete (), oldbuf. discrete)
82
+ const_idxs = subset_idxs (idxs, SciMLStructures. Constants (), oldbuf. constant)
83
+ nn_idxs = subset_idxs (idxs, NONNUMERIC_PORTION, oldbuf. nonnumeric)
90
84
91
85
pullback = let idxs = idxs
92
86
function remake_buffer_pullback (buf′)
@@ -102,13 +96,11 @@ function ChainRulesCore.rrule(
102
96
oldbuf′ = Tangent {typeof(oldbuf)} (;
103
97
tunable, initials, discrete, constant, nonnumeric)
104
98
idxs′ = NoTangent ()
105
- vals′ = map (i -> MTK . _ducktyped_parameter_values (buf′, i), idxs)
99
+ vals′ = map (i -> _ducktyped_parameter_values (buf′, i), idxs)
106
100
return f′, indp′, oldbuf′, idxs′, vals′
107
101
end
108
102
end
109
103
newbuf, pullback
110
104
end
111
105
112
- ChainRulesCore. @non_differentiable Base. getproperty (sys:: MTK.AbstractSystem , x:: Symbol )
113
-
114
- end
106
+ ChainRulesCore. @non_differentiable Base. getproperty (sys:: AbstractSystem , x:: Symbol )
0 commit comments