Skip to content

Commit da5fa87

Browse files
authored
Merge pull request #2257 from SciML/myb/sub_sys
Implement `substitute` for `AbstractSystem`s
2 parents cc7a102 + 5981d75 commit da5fa87

File tree

4 files changed

+38
-8
lines changed

4 files changed

+38
-8
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ $(DocStringExtensions.README)
33
"""
44
module ModelingToolkit
55
using PrecompileTools, Reexport
6-
@recompile_invalidations begin
6+
@recompile_invalidations begin
77
using DocStringExtensions
88
using Compat
99
using AbstractTrees

src/systems/abstractsystem.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1831,3 +1831,21 @@ function missing_variable_defaults(sys::AbstractSystem, default = 0.0)
18311831

18321832
return ds
18331833
end
1834+
1835+
keytype(::Type{<:Pair{T, V}}) where {T, V} = T
1836+
function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair}, Dict})
1837+
if keytype(eltype(rules)) <: Symbol
1838+
dict = todict(rules)
1839+
systems = get_systems(sys)
1840+
# post-walk to avoid infinite recursion
1841+
@set! sys.systems = map(Base.Fix2(substitute, dict), systems)
1842+
something(get(rules, nameof(sys), nothing), sys)
1843+
elseif sys isa ODESystem
1844+
rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]),
1845+
collect(rules)))
1846+
eqs = fast_substitute(equations(sys), rules)
1847+
ODESystem(eqs, get_iv(sys); name = nameof(sys))
1848+
else
1849+
error("substituting symbols is not supported for $(typeof(sys))")
1850+
end
1851+
end

src/systems/diffeqs/odesystem.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -465,13 +465,6 @@ function convert_system(::Type{<:ODESystem}, sys, t; name = nameof(sys))
465465
checks = false)
466466
end
467467

468-
function Symbolics.substitute(sys::ODESystem, rules::Union{Vector{<:Pair}, Dict})
469-
rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]),
470-
collect(rules)))
471-
eqs = fast_substitute(equations(sys), rules)
472-
ODESystem(eqs, get_iv(sys); name = nameof(sys))
473-
end
474-
475468
"""
476469
$(SIGNATURES)
477470

test/odesystem.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,3 +1012,22 @@ let
10121012
prob = ODAEProblem(sys4s, [x => 1.0, D(x) => 1.0], (0, 1.0))
10131013
@test !isnothing(prob.f.sys)
10141014
end
1015+
1016+
@parameters t
1017+
# SYS 1:
1018+
vars_sub1 = @variables s1(t)
1019+
@named sub = ODESystem(Equation[], t, vars_sub1, [])
1020+
1021+
vars1 = @variables x1(t)
1022+
@named sys1 = ODESystem(Equation[], t, vars1, [], systems = [sub])
1023+
@named sys2 = ODESystem(Equation[], t, vars1, [], systems = [sys1, sub])
1024+
1025+
# SYS 2: Extension to SYS 1
1026+
vars_sub2 = @variables s2(t)
1027+
@named partial_sub = ODESystem(Equation[], t, vars_sub2, [])
1028+
@named sub = extend(partial_sub, sub)
1029+
1030+
new_sys2 = complete(substitute(sys2, Dict(:sub => sub)))
1031+
Set(states(new_sys2)) == Set([new_sys2.x1, new_sys2.sys1.x1,
1032+
new_sys2.sys1.sub.s1, new_sys2.sys1.sub.s2,
1033+
new_sys2.sub.s1, new_sys2.sub.s2])

0 commit comments

Comments
 (0)