Skip to content

Commit 1de4cd5

Browse files
refactor: move ChainRulesCoreExt into main package
1 parent a35ae7d commit 1de4cd5

File tree

3 files changed

+17
-22
lines changed

3 files changed

+17
-22
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
99
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1010
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
11+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1112
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1213
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1314
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -65,7 +66,6 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
6566
[weakdeps]
6667
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
6768
CasADi = "c49709b8-5c63-11e9-2fb2-69db5844192f"
68-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6969
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
7070
FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac"
7171
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
@@ -74,7 +74,6 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
7474
[extensions]
7575
MTKBifurcationKitExt = "BifurcationKit"
7676
MTKCasADiDynamicOptExt = "CasADi"
77-
MTKChainRulesCoreExt = "ChainRulesCore"
7877
MTKDeepDiffsExt = "DeepDiffs"
7978
MTKFMIExt = "FMI"
8079
MTKInfiniteOptExt = "InfiniteOpt"

src/ModelingToolkit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, bloc
6262
using OffsetArrays: Origin
6363
import CommonSolve
6464
import EnumX
65+
import ChainRulesCore
66+
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
6567

6668
using RuntimeGeneratedFunctions
6769
using RuntimeGeneratedFunctions: drop_expr
@@ -204,6 +206,8 @@ include("structural_transformation/StructuralTransformations.jl")
204206
@reexport using .StructuralTransformations
205207
include("inputoutput.jl")
206208

209+
include("adjoints.jl")
210+
207211
for S in subtypes(ModelingToolkit.AbstractSystem)
208212
S = nameof(S)
209213
@eval convert_system(::Type{<:$S}, sys::$S) = sys

ext/MTKChainRulesCoreExt.jl renamed to src/adjoints.jl

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
1-
module MTKChainRulesCoreExt
2-
3-
import ModelingToolkit as MTK
4-
import ChainRulesCore
5-
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
6-
7-
function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
1+
function ChainRulesCore.rrule(::Type{MTKParameters}, tunables, args...)
82
function mtp_pullback(dt)
93
dt = unthunk(dt)
104
dtunables = dt isa AbstractArray ? dt : dt.tunable
115
(NoTangent(), dtunables[1:length(tunables)],
126
ntuple(_ -> NoTangent(), length(args))...)
137
end
14-
MTK.MTKParameters(tunables, args...), mtp_pullback
8+
MTKParameters(tunables, args...), mtp_pullback
159
end
1610

1711
function subset_idxs(idxs, portion, template)
@@ -70,23 +64,23 @@ function selected_tangents(
7064
end
7165

7266
function ChainRulesCore.rrule(
73-
::typeof(MTK.remake_buffer), indp, oldbuf::MTK.MTKParameters, idxs, vals)
67+
::typeof(remake_buffer), indp, oldbuf::MTKParameters, idxs, vals)
7468
if idxs isa AbstractSet
7569
idxs = collect(idxs)
7670
end
7771
idxs = map(idxs) do i
78-
i isa MTK.ParameterIndex ? i : MTK.parameter_index(indp, i)
72+
i isa ParameterIndex ? i : parameter_index(indp, i)
7973
end
80-
newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals)
74+
newbuf = remake_buffer(indp, oldbuf, idxs, vals)
8175
tunable_idxs = reduce(
82-
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable);
76+
vcat, (idx.idx for idx in idxs if idx.portion isa SciMLStructures.Tunable);
8377
init = Union{Int, AbstractVector{Int}}[])
8478
initials_idxs = reduce(
85-
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Initials);
79+
vcat, (idx.idx for idx in idxs if idx.portion isa SciMLStructures.Initials);
8680
init = Union{Int, AbstractVector{Int}}[])
87-
disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete)
88-
const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
89-
nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)
81+
disc_idxs = subset_idxs(idxs, SciMLStructures.Discrete(), oldbuf.discrete)
82+
const_idxs = subset_idxs(idxs, SciMLStructures.Constants(), oldbuf.constant)
83+
nn_idxs = subset_idxs(idxs, NONNUMERIC_PORTION, oldbuf.nonnumeric)
9084

9185
pullback = let idxs = idxs
9286
function remake_buffer_pullback(buf′)
@@ -102,13 +96,11 @@ function ChainRulesCore.rrule(
10296
oldbuf′ = Tangent{typeof(oldbuf)}(;
10397
tunable, initials, discrete, constant, nonnumeric)
10498
idxs′ = NoTangent()
105-
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
99+
vals′ = map(i -> _ducktyped_parameter_values(buf′, i), idxs)
106100
return f′, indp′, oldbuf′, idxs′, vals′
107101
end
108102
end
109103
newbuf, pullback
110104
end
111105

112-
ChainRulesCore.@non_differentiable Base.getproperty(sys::MTK.AbstractSystem, x::Symbol)
113-
114-
end
106+
ChainRulesCore.@non_differentiable Base.getproperty(sys::AbstractSystem, x::Symbol)

0 commit comments

Comments
 (0)