Skip to content

Commit 3678131

Browse files
committed
Fix a few minor bugs
1 parent 21be67e commit 3678131

File tree

4 files changed

+37
-44
lines changed

4 files changed

+37
-44
lines changed

src/structural_transformation/tearing.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,11 @@ function tearing_reassemble(sys; simplify=false)
122122

123123

124124
### update equations
125-
newstates = setdiff([fullvars[diffvars_range(s)]; fullvars[algvars_range(s)]], solvars)
125+
odestats = []
126+
for idx in eachindex(fullvars); isdervar(s, idx) && continue
127+
push!(odestats, fullvars[idx])
128+
end
129+
newstates = setdiff(odestats, solvars)
126130
varidxmap = Dict(newstates .=> 1:length(newstates))
127131
neweqs = Vector{Equation}(undef, ns)
128132
newalgeqs = falses(ns)

src/symutils.jl

Whitespace-only changes.

src/systems/systemstructure.jl

Lines changed: 31 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using DataStructures
44
using SymbolicUtils: istree, operation, arguments, Symbolic
55
using ..ModelingToolkit
66
import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten,
7-
value, InvalidSystemException, isdifferential
7+
value, InvalidSystemException, isdifferential, _iszero, isparameter
88
using ..BipartiteGraphs
99
using UnPack
1010
using Setfield
@@ -39,6 +39,7 @@ end
3939

4040
export SystemStructure, initialize_system_structure, find_linear_equations
4141
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq
42+
export dervars_range, diffvars_range, algvars_range
4243

4344
@enum VariableType::Int8 DIFFERENTIAL_VARIABLE ALGEBRAIC_VARIABLE DERIVATIVE_VARIABLE
4445

@@ -68,8 +69,10 @@ isalgeq(s::SystemStructure, eq::Integer) = s.algeqs[eq]
6869
isdiffeq(s::SystemStructure, eq::Integer) = !isalgeq(s, eq)
6970

7071
function initialize_system_structure(sys)
72+
sys = flatten(sys)
73+
7174
iv = independent_variable(sys)
72-
eqs = equations(sys)
75+
eqs = copy(equations(sys))
7376
neqs = length(eqs)
7477
algeqs = trues(neqs)
7578
dervaridxs = Int[]
@@ -81,13 +84,17 @@ function initialize_system_structure(sys)
8184
for (i, eq) in enumerate(eqs)
8285
vars = OrderedSet()
8386
vars!(vars, eq)
84-
push!(symbolic_incidence, copy(vars))
8587
isalgeq = true
88+
statevars = []
8689
for var in vars
87-
varidx = get(var2idx, var, 0)
88-
if varidx == 0 # new var
89-
var_counter += 1
90+
isequal(var, iv) && continue
91+
if isparameter(var) || (istree(var) && isparameter(operation(var)))
92+
continue
93+
end
94+
push!(statevars, var)
95+
varidx = get!(var2idx, var) do
9096
push!(fullvars, var)
97+
var_counter += 1
9198
end
9299

93100
if isdifferential(var)
@@ -99,13 +106,21 @@ function initialize_system_structure(sys)
99106
push!(dervaridxs, varidx)
100107
end
101108
end
109+
push!(symbolic_incidence, copy(statevars))
110+
empty!(statevars)
102111
algeqs[i] = isalgeq
112+
if isalgeq && !_iszero(eq.lhs)
113+
eqs[i] = 0 ~ eq.rhs - eq.lhs
114+
end
103115
end
104116

117+
nvars = length(fullvars)
105118
diffvars = []
106-
varassoc = zeros(Int, length(fullvars))
107-
inv_varassoc = zeros(Int, length(fullvars))
119+
vartype = fill(DIFFERENTIAL_VARIABLE, nvars)
120+
varassoc = zeros(Int, nvars)
121+
inv_varassoc = zeros(Int, nvars)
108122
for dervaridx in dervaridxs
123+
vartype[dervaridx] = DERIVATIVE_VARIABLE
109124
dervar = fullvars[dervaridx]
110125
diffvar = arguments(dervar)[1]
111126
diffvaridx = get(var2idx, diffvar, 0)
@@ -121,20 +136,19 @@ function initialize_system_structure(sys)
121136
# it could be that a variable appeared in the states, but never appeared
122137
# in the equations.
123138
algvaridx = get(var2idx, algvar, 0)
124-
if algvaridx != 0
125-
varassoc[algvaridx] = -1
126-
end
139+
vartype[algvaridx] = ALGEBRAIC_VARIABLE
127140
end
128141

129-
neqs, nvars = length(eqs), length(fullvars)
130142
graph = BipartiteGraph(neqs, nvars)
131143
for (ie, vars) in enumerate(symbolic_incidence), v in vars
132144
jv = var2idx[v]
133145
add_edge!(graph, ie, jv)
134146
end
135147

136-
SystemStructure(
148+
@set! sys.eqs = eqs
149+
@set! sys.structure = SystemStructure(
137150
fullvars = fullvars,
151+
vartype = vartype,
138152
varassoc = varassoc,
139153
inv_varassoc = inv_varassoc,
140154
algeqs = algeqs,
@@ -145,6 +159,7 @@ function initialize_system_structure(sys)
145159
scc = Vector{Int}[],
146160
partitions = NTuple{4, Vector{Int}}[],
147161
)
162+
return sys
148163
end
149164

150165
function find_linear_equations(sys)
@@ -194,41 +209,15 @@ function find_linear_equations(sys)
194209
is_linear_equations[i] = false
195210
end
196211
end
197-
sys, dxvar_offset, fullvars, varassoc, algeqs, graph = init_graph(flatten(sys))
198-
199-
solvable_graph = BipartiteGraph(nsrcs(graph), ndsts(graph))
200-
201-
@set sys.structure = SystemStructure(
202-
dxvar_offset,
203-
fullvars,
204-
varassoc,
205-
algeqs,
206-
graph,
207-
solvable_graph,
208-
Int[],
209-
Int[],
210-
Vector{Int}[],
211-
NTuple{4, Vector{Int}}[]
212-
)
212+
213+
return is_linear_equations, eadj, cadj
213214
end
214215

215216
function Base.show(io::IO, s::SystemStructure)
216-
@unpack fullvars, dxvar_offset, solvable_graph, graph = s
217-
algvar_offset = 2dxvar_offset
218-
print(io, "xvars: ")
219-
print(io, fullvars[1:dxvar_offset])
220-
print(io, "\ndxvars: ")
221-
print(io, fullvars[dxvar_offset+1:algvar_offset])
222-
print(io, "\nalgvars: ")
223-
print(io, fullvars[algvar_offset+1:end], '\n')
224-
217+
@unpack graph = s
225218
S = incidence_matrix(graph, Num(Sym{Real}(:×)))
226219
print(io, "Incidence matrix:")
227220
show(io, S)
228221
end
229222

230-
function init_graph(sys)
231-
return is_linear_equations, eadj, cadj
232-
end
233-
234223
end # module

test/reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ eqs1 = [
6767
lorenz = name -> ODESystem(eqs1,t,name=name)
6868
lorenz1 = lorenz(:lorenz1)
6969
ss = ModelingToolkit.get_structure(initialize_system_structure(lorenz1))
70-
@test isequal(ss.fullvars, [x, y, z, D(x), D(y), D(z), F, u])
70+
@test isequal(ss.fullvars, [D(x), F, y, x, D(y), u, z, D(z)])
7171
lorenz2 = lorenz(:lorenz2)
7272

7373
connected = ODESystem([s ~ a + lorenz1.x

0 commit comments

Comments
 (0)