Skip to content

Commit 21be67e

Browse files
committed
Relaxing some assumptions
1 parent 57a1325 commit 21be67e

File tree

1 file changed

+99
-92
lines changed

1 file changed

+99
-92
lines changed

src/systems/systemstructure.jl

Lines changed: 99 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using DataStructures
44
using SymbolicUtils: istree, operation, arguments, Symbolic
55
using ..ModelingToolkit
66
import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten,
7-
value, InvalidSystemException
7+
value, InvalidSystemException, isdifferential
88
using ..BipartiteGraphs
99
using UnPack
1010
using Setfield
@@ -38,15 +38,14 @@ end
3838
=#
3939

4040
export SystemStructure, initialize_system_structure, find_linear_equations
41-
export diffvars_range, dervars_range, algvars_range
4241
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq
43-
export DIFFERENTIAL_VARIABLE, ALGEBRAIC_VARIABLE, DERIVATIVE_VARIABLE
44-
export DIFFERENTIAL_EQUATION, ALGEBRAIC_EQUATION
45-
export vartype, eqtype
4642

47-
struct SystemStructure
48-
dxvar_offset::Int
49-
fullvars::Vector # [diffvars; dervars; algvars]
43+
@enum VariableType::Int8 DIFFERENTIAL_VARIABLE ALGEBRAIC_VARIABLE DERIVATIVE_VARIABLE
44+
45+
Base.@kwdef struct SystemStructure
46+
fullvars::Vector
47+
vartype::Vector{VariableType}
48+
inv_varassoc::Vector{Int}
5049
varassoc::Vector{Int}
5150
algeqs::BitVector
5251
graph::BipartiteGraph{Int,Nothing}
@@ -57,121 +56,95 @@ struct SystemStructure
5756
partitions::Vector{NTuple{4, Vector{Int}}}
5857
end
5958

60-
diffvars_range(s::SystemStructure) = 1:s.dxvar_offset
61-
# TODO: maybe dervars should be in the end.
62-
dervars_range(s::SystemStructure) = s.dxvar_offset+1:2s.dxvar_offset
63-
algvars_range(s::SystemStructure) = 2s.dxvar_offset+1:length(s.fullvars)
64-
65-
isdiffvar(s::SystemStructure, var::Integer) = var in diffvars_range(s)
66-
isdervar(s::SystemStructure, var::Integer) = var in dervars_range(s)
67-
isalgvar(s::SystemStructure, var::Integer) = var in algvars_range(s)
68-
69-
@enum VariableType DIFFERENTIAL_VARIABLE ALGEBRAIC_VARIABLE DERIVATIVE_VARIABLE
59+
isdervar(s::SystemStructure, var::Integer) = s.vartype[var] === DERIVATIVE_VARIABLE
60+
isdiffvar(s::SystemStructure, var::Integer) = s.vartype[var] === DIFFERENTIAL_VARIABLE
61+
isalgvar(s::SystemStructure, var::Integer) = s.vartype[var] === ALGEBRAIC_VARIABLE
7062

71-
function vartype(s::SystemStructure, var::Integer)::VariableType
72-
isdiffvar(s, var) ? DIFFERENTIAL_VARIABLE :
73-
isdervar(s, var) ? DERIVATIVE_VARIABLE :
74-
isalgvar(s, var) ? ALGEBRAIC_VARIABLE : error("Variable $var out of bounds")
75-
end
76-
77-
@enum EquationType DIFFERENTIAL_EQUATION ALGEBRAIC_EQUATION
63+
dervars_range(s::SystemStructure) = Iterators.filter(Base.Fix1(s, isdervar), eachindex(s.vartype))
64+
diffvars_range(s::SystemStructure) = Iterators.filter(Base.Fix1(s, isdiffvar), eachindex(s.vartype))
65+
algvars_range(s::SystemStructure) = Iterators.filter(Base.Fix1(s, isalgeq), eachindex(s.vartype))
7866

7967
isalgeq(s::SystemStructure, eq::Integer) = s.algeqs[eq]
8068
isdiffeq(s::SystemStructure, eq::Integer) = !isalgeq(s, eq)
81-
eqtype(s::SystemStructure, eq::Integer)::EquationType = isalgeq(s, eq) ? ALGEBRAIC_EQUATION : DIFFERENTIAL_EQUATION
8269

8370
function initialize_system_structure(sys)
84-
sys, dxvar_offset, fullvars, varassoc, algeqs, graph = init_graph(flatten(sys))
85-
86-
solvable_graph = BipartiteGraph(nsrcs(graph), ndsts(graph))
87-
88-
@set sys.structure = SystemStructure(
89-
dxvar_offset,
90-
fullvars,
91-
varassoc,
92-
algeqs,
93-
graph,
94-
solvable_graph,
95-
Int[],
96-
Int[],
97-
Vector{Int}[],
98-
NTuple{4, Vector{Int}}[]
99-
)
100-
end
101-
102-
function Base.show(io::IO, s::SystemStructure)
103-
@unpack fullvars, dxvar_offset, solvable_graph, graph = s
104-
algvar_offset = 2dxvar_offset
105-
print(io, "xvars: ")
106-
print(io, fullvars[1:dxvar_offset])
107-
print(io, "\ndxvars: ")
108-
print(io, fullvars[dxvar_offset+1:algvar_offset])
109-
print(io, "\nalgvars: ")
110-
print(io, fullvars[algvar_offset+1:end], '\n')
111-
112-
S = incidence_matrix(graph, Num(Sym{Real}(:×)))
113-
print(io, "Incidence matrix:")
114-
show(io, S)
115-
end
116-
117-
function init_graph(sys)
11871
iv = independent_variable(sys)
11972
eqs = equations(sys)
12073
neqs = length(eqs)
12174
algeqs = trues(neqs)
122-
varsadj = Vector{Any}(undef, neqs)
123-
dervars = OrderedSet()
124-
diffvars = OrderedSet()
75+
dervaridxs = Int[]
76+
var2idx = Dict{Any,Int}()
77+
symbolic_incidence = []
78+
fullvars = []
79+
var_counter = 0
12580

12681
for (i, eq) in enumerate(eqs)
12782
vars = OrderedSet()
12883
vars!(vars, eq)
84+
push!(symbolic_incidence, copy(vars))
12985
isalgeq = true
13086
for var in vars
131-
if istree(var) && operation(var) isa Differential
87+
varidx = get(var2idx, var, 0)
88+
if varidx == 0 # new var
89+
var_counter += 1
90+
push!(fullvars, var)
91+
end
92+
93+
if isdifferential(var)
13294
isalgeq = false
13395
diffvar = arguments(var)[1]
13496
if diffvar isa Differential
13597
throw(InvalidSystemException("The equation [ $eq ] is not first order"))
13698
end
137-
push!(dervars, var)
138-
push!(diffvars, diffvar)
99+
push!(dervaridxs, varidx)
139100
end
140101
end
141102
algeqs[i] = isalgeq
142-
varsadj[i] = vars
143103
end
144104

145-
algvars = setdiff(states(sys), diffvars)
146-
fullvars = [collect(diffvars); collect(dervars); algvars]
105+
diffvars = []
106+
varassoc = zeros(Int, length(fullvars))
107+
inv_varassoc = zeros(Int, length(fullvars))
108+
for dervaridx in dervaridxs
109+
dervar = fullvars[dervaridx]
110+
diffvar = arguments(dervar)[1]
111+
diffvaridx = get(var2idx, diffvar, 0)
112+
if diffvaridx != 0
113+
push!(diffvars, diffvar)
114+
varassoc[diffvaridx] = dervaridx
115+
inv_varassoc[dervaridx] = diffvaridx
116+
end
117+
end
147118

148-
dxvar_offset = length(diffvars)
149-
algvar_offset = 2dxvar_offset
119+
algvars = setdiff(states(sys), diffvars)
120+
for algvar in algvars
121+
# it could be that a variable appeared in the states, but never appeared
122+
# in the equations.
123+
algvaridx = get(var2idx, algvar, 0)
124+
if algvaridx != 0
125+
varassoc[algvaridx] = -1
126+
end
127+
end
150128

151-
nvars = length(fullvars)
152-
idxmap = Dict(fullvars .=> 1:nvars)
129+
neqs, nvars = length(eqs), length(fullvars)
153130
graph = BipartiteGraph(neqs, nvars)
154-
155-
vs = Set()
156-
for (i, eq) in enumerate(eqs)
157-
# TODO: custom vars that handles D(x)
158-
# TODO: add checks here
159-
lhs = eq.lhs
160-
if isdiffeq(eq)
161-
v = lhs
162-
haskey(idxmap, v) && add_edge!(graph, i, idxmap[v])
163-
else
164-
vars!(vs, lhs)
165-
end
166-
vars!(vs, eq.rhs)
167-
for v in vs
168-
haskey(idxmap, v) && add_edge!(graph, i, idxmap[v])
169-
end
170-
empty!(vs)
131+
for (ie, vars) in enumerate(symbolic_incidence), v in vars
132+
jv = var2idx[v]
133+
add_edge!(graph, ie, jv)
171134
end
172135

173-
varassoc = Int[(1:dxvar_offset) .+ dxvar_offset; zeros(Int, length(fullvars) - dxvar_offset)] # variable association list
174-
sys, dxvar_offset, fullvars, varassoc, algeqs, graph
136+
SystemStructure(
137+
fullvars = fullvars,
138+
varassoc = varassoc,
139+
inv_varassoc = inv_varassoc,
140+
algeqs = algeqs,
141+
graph = graph,
142+
solvable_graph = BipartiteGraph(nsrcs(graph), ndsts(graph)),
143+
assign = Int[],
144+
inv_assign = Int[],
145+
scc = Vector{Int}[],
146+
partitions = NTuple{4, Vector{Int}}[],
147+
)
175148
end
176149

177150
function find_linear_equations(sys)
@@ -221,6 +194,40 @@ function find_linear_equations(sys)
221194
is_linear_equations[i] = false
222195
end
223196
end
197+
sys, dxvar_offset, fullvars, varassoc, algeqs, graph = init_graph(flatten(sys))
198+
199+
solvable_graph = BipartiteGraph(nsrcs(graph), ndsts(graph))
200+
201+
@set sys.structure = SystemStructure(
202+
dxvar_offset,
203+
fullvars,
204+
varassoc,
205+
algeqs,
206+
graph,
207+
solvable_graph,
208+
Int[],
209+
Int[],
210+
Vector{Int}[],
211+
NTuple{4, Vector{Int}}[]
212+
)
213+
end
214+
215+
function Base.show(io::IO, s::SystemStructure)
216+
@unpack fullvars, dxvar_offset, solvable_graph, graph = s
217+
algvar_offset = 2dxvar_offset
218+
print(io, "xvars: ")
219+
print(io, fullvars[1:dxvar_offset])
220+
print(io, "\ndxvars: ")
221+
print(io, fullvars[dxvar_offset+1:algvar_offset])
222+
print(io, "\nalgvars: ")
223+
print(io, fullvars[algvar_offset+1:end], '\n')
224+
225+
S = incidence_matrix(graph, Num(Sym{Real}(:×)))
226+
print(io, "Incidence matrix:")
227+
show(io, S)
228+
end
229+
230+
function init_graph(sys)
224231
return is_linear_equations, eadj, cadj
225232
end
226233

0 commit comments

Comments
 (0)