Skip to content

Commit ecdcdb1

Browse files
committed
Add SystemStructure
1 parent f0c2a83 commit ecdcdb1

File tree

4 files changed

+157
-1
lines changed

4 files changed

+157
-1
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ include("domains.jl")
183183
include("register_function.jl")
184184

185185
include("systems/abstractsystem.jl")
186+
include("systems/systemstructure.jl")
186187

187188
include("systems/diffeqs/odesystem.jl")
188189
include("systems/diffeqs/sdesystem.jl")
@@ -213,6 +214,7 @@ include("extra_functions.jl")
213214

214215
export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr
215216
export SDESystem, SDEFunction, SDEFunctionExpr, SDESystemExpr
217+
export SystemStructure
216218
export JumpSystem
217219
export ODEProblem, SDEProblem
218220
export NonlinearProblem, NonlinearProblemExpr

src/systems/abstractsystem.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,4 +297,3 @@ function (f::AbstractSysToExpr)(O)
297297
end
298298
return build_expr(:call, Any[operation(O); f.(arguments(O))])
299299
end
300-

src/systems/systemstructure.jl

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
using SparseArrays
2+
3+
const SHOW_EQUATIONS = Ref(false)
4+
struct SystemStructure
5+
sys
6+
dxvar_offset::Int
7+
fullvars::Vector # [xvar; dxvars; algvars]
8+
varassoc::Vector{Int}
9+
graph::BipartiteGraph{Int}
10+
solvable_graph::BipartiteGraph{Int}
11+
end
12+
function SystemStructure(sys)
13+
sys = ModelingToolkit.flatten(sys)
14+
sys, dxvar_offset, fullvars, varassoc, graph, solvable_graph = init_graph(sys)
15+
SystemStructure(sys, dxvar_offset, fullvars, varassoc, graph, solvable_graph)
16+
end
17+
18+
ModelingToolkit.equations(s::SystemStructure) = equations(s.sys)
19+
20+
function Base.show(io::IO, s::SystemStructure)
21+
@unpack fullvars, dxvar_offset, solvable_graph, graph = s
22+
algvar_offset = 2dxvar_offset
23+
print(io, "xvars: ")
24+
print(io, fullvars[1:dxvar_offset])
25+
print(io, "\ndxvars: ")
26+
print(io, fullvars[dxvar_offset+1:algvar_offset])
27+
print(io, "\nalgvars: ")
28+
print(io, fullvars[algvar_offset+1:end], '\n')
29+
30+
if SHOW_EQUATIONS[]
31+
println(io, "Edges:")
32+
eqs = equations(s)
33+
for ev in 𝑠vertices(graph)
34+
print(io, " $(eqs[ev])\n -> ")
35+
vars = 𝑠neighbors(graph, ev)
36+
solvars = 𝑠neighbors(solvable_graph, ev)
37+
solvable = intersect(vars, solvars)
38+
notsolvable = setdiff(vars, solvars)
39+
40+
print(io, join(string.(fullvars[notsolvable]), ", "))
41+
for ii in solvable
42+
print(io, ", ")
43+
var = fullvars[ii]
44+
Base.printstyled(io, string(fullvars[ii]), color=:cyan)
45+
end
46+
println(io)
47+
end
48+
end
49+
50+
S = incidence_matrix(graph, Num(Sym{Real}(:×)))
51+
print(io, "Incidence matrix:")
52+
show(io, S)
53+
end
54+
55+
# V-nodes `[x_1, x_2, x_3, ..., dx_1, dx_2, ..., y_1, y_2, ...]` where `x`s are
56+
# differential variables and `y`s are algebraic variables.
57+
function get_vnodes(sys)
58+
dxvars = []
59+
eqs = equations(sys)
60+
for (i, eq) in enumerate(eqs)
61+
if eq.lhs isa Symbolic
62+
# Make sure that the LHS is a first order derivative of a var.
63+
@assert operation(eq.lhs) isa Differential "The equation $eq is not in the form of `D(...) ~ ...`"
64+
@assert !(arguments(eq.lhs)[1] isa Differential) "The equation $eq is not first order"
65+
66+
push!(dxvars, eq.lhs)
67+
end
68+
end
69+
70+
xvars = (first var_from_nested_derivative).(dxvars)
71+
algvars = setdiff(states(sys), xvars)
72+
return xvars, dxvars, algvars
73+
end
74+
75+
function init_graph(sys)
76+
xvars, dxvars, algvars = get_vnodes(sys)
77+
dxvar_offset = length(xvars)
78+
algvar_offset = 2dxvar_offset
79+
80+
fullvars = [xvars; dxvars; algvars]
81+
sys = reordersys(sys, dxvar_offset, fullvars)
82+
eqs = equations(sys)
83+
idxmap = Dict(fullvars .=> 1:length(fullvars))
84+
graph = BipartiteGraph(length(eqs), length(fullvars))
85+
solvable_graph = BipartiteGraph(length(eqs), length(fullvars))
86+
87+
for (i, eq) in enumerate(eqs)
88+
if isdiffeq(eq)
89+
v = eq.lhs
90+
haskey(idxmap, v) && add_edge!(graph, i, idxmap[v])
91+
end
92+
# TODO: custom vars that handles D(x)
93+
vs = vars(eq.rhs)
94+
for v in vs
95+
haskey(idxmap, v) && add_edge!(graph, i, idxmap[v])
96+
end
97+
end
98+
99+
varassoc = Int[(1:dxvar_offset) .+ dxvar_offset; zeros(Int, length(fullvars) - dxvar_offset)] # variable association list
100+
sys, dxvar_offset, fullvars, varassoc, graph, solvable_graph
101+
end
102+
103+
function reordersys(sys, dxvar_offset, fullvars)
104+
eqs = equations(sys)
105+
neweqs = Vector{Equation}(undef, length(eqs))
106+
eqidxmap = Dict(@view(fullvars[dxvar_offset+1:2dxvar_offset]) .=> (1:dxvar_offset))
107+
varidxmap = Dict([@view(fullvars[1:dxvar_offset]); @view(fullvars[2dxvar_offset+1:end])] .=> (1:length(fullvars)-dxvar_offset))
108+
algidx = dxvar_offset
109+
for eq in eqs
110+
if isdiffeq(eq)
111+
neweqs[eqidxmap[eq.lhs]] = eq
112+
else
113+
neweqs[algidx+=1] = eq
114+
end
115+
end
116+
sts = states(sys)
117+
@set! sys.eqs = neweqs
118+
@set! sys.states = sts[map(s->varidxmap[s], sts)]
119+
end

test/dep_graphs.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using Test
12
using ModelingToolkit, LightGraphs
23

34
import ModelingToolkit: value
@@ -101,3 +102,38 @@ ns = NonlinearSystem(eqs, [x,y,z],[σ,ρ,β])
101102
deps = equation_dependencies(ns)
102103
eq_sdeps = [[x,y],[y],[y,z]]
103104
@test all(i -> isequal(Set(deps[i]),Set(value.(eq_sdeps[i]))), 1:length(deps))
105+
106+
using SparseArrays
107+
using ModelingToolkit
108+
using UnPack
109+
110+
# Define some variables
111+
@parameters t L g
112+
@variables x(t) y(t) w(t) z(t) T(t)
113+
@derivatives D'~t
114+
115+
# Simple pendulum in cartesian coordinates
116+
eqs = [D(x) ~ w,
117+
D(y) ~ z,
118+
D(w) ~ T*x,
119+
D(z) ~ T*y - g,
120+
0 ~ x^2 + y^2 - L^2]
121+
pendulum = ODESystem(eqs, t, [x, y, w, z, T], [L, g], name=:pendulum)
122+
sss = SystemStructure(pendulum)
123+
@unpack graph, fullvars, varassoc = sss
124+
@test isequal(fullvars, [x, y, w, z, D(x), D(y), D(w), D(z), T])
125+
@test graph.fadjlist == sort.([[5, 3], [6, 4], [7, 1, 9], [8, 2, 9], [2, 1]])
126+
@test graph.badjlist == sort.([[3, 5], [4, 5], [1], [2], [1], [2], [3], [4], [3, 4]])
127+
@test LightGraphs.ne(graph) == nnz(incidence_matrix(graph)) == 12
128+
@test varassoc == [5, 6, 7, 8, 0, 0, 0, 0, 0]
129+
130+
se = collect(ModelingToolkit.𝑠edges(graph))
131+
@test se == mapreduce(vcat, enumerate(graph.fadjlist)) do (s, d)
132+
ModelingToolkit.BipartiteEdge.(s, d)
133+
end
134+
de = collect(ModelingToolkit.𝑑edges(graph))
135+
@test de == mapreduce(vcat, enumerate(graph.badjlist)) do (d, s)
136+
ModelingToolkit.BipartiteEdge.(s, d)
137+
end
138+
ae = collect(ModelingToolkit.edges(graph))
139+
@test ae == vcat(se, de)

0 commit comments

Comments
 (0)