Skip to content

Commit 38c5b42

Browse files
authored
Merge pull request #1357 from SciML/myb/fixcompact
Fix `compact_graph!` in tearing
2 parents ecf7912 + 7b8e0ef commit 38c5b42

File tree

3 files changed

+188
-181
lines changed

3 files changed

+188
-181
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ include("utils.jl")
4444
include("pantelides.jl")
4545
include("bipartite_tearing/modia_tearing.jl")
4646
include("tearing.jl")
47+
include("symbolics_tearing.jl")
4748
include("codegen.jl")
4849

4950
end # module
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
function tearing_sub(expr, dict, s)
2+
expr = ModelingToolkit.fixpoint_sub(expr, dict)
3+
s ? simplify(expr) : expr
4+
end
5+
6+
function tearing_reassemble(sys; simplify=false)
7+
s = structure(sys)
8+
@unpack fullvars, partitions, var_eq_matching, graph, scc = s
9+
eqs = equations(sys)
10+
11+
### extract partition information
12+
rhss = []
13+
solvars = []
14+
ns, nd = nsrcs(graph), ndsts(graph)
15+
active_eqs = trues(ns)
16+
active_vars = trues(nd)
17+
rvar2reqs = Vector{Vector{Int}}(undef, nd)
18+
for (ith_scc, partition) in enumerate(partitions)
19+
@unpack e_solved, v_solved, e_residual, v_residual = partition
20+
for ii in eachindex(e_solved)
21+
ieq = e_solved[ii]; ns -= 1
22+
iv = v_solved[ii]; nd -= 1
23+
rvar2reqs[iv] = e_solved
24+
25+
active_eqs[ieq] = false
26+
active_vars[iv] = false
27+
28+
eq = eqs[ieq]
29+
var = fullvars[iv]
30+
rhs = value(solve_for(eq, var; simplify=simplify, check=false))
31+
# if we don't simplify the rhs and the `eq` is not solved properly
32+
(!simplify && occursin(rhs, var)) && (rhs = SymbolicUtils.polynormalize(rhs))
33+
# Since we know `eq` is linear wrt `var`, so the round off must be a
34+
# linear term. We can correct the round off error by a linear
35+
# correction.
36+
rhs -= expand_derivatives(Differential(var)(rhs))*var
37+
@assert !(var in vars(rhs)) """
38+
When solving
39+
$eq
40+
$var remainded in
41+
$rhs.
42+
"""
43+
push!(rhss, rhs)
44+
push!(solvars, var)
45+
end
46+
# DEBUG:
47+
#@show ith_scc solvars .~ rhss
48+
#Main._nlsys[] = eqs[e_solved], fullvars[v_solved]
49+
#ModelingToolkit.topsort_equations(solvars .~ rhss, fullvars)
50+
#empty!(solvars); empty!(rhss)
51+
end
52+
53+
### update SCC
54+
eq_reidx = Vector{Int}(undef, nsrcs(graph))
55+
idx = 0
56+
for (i, active) in enumerate(active_eqs)
57+
eq_reidx[i] = active ? (idx += 1) : -1
58+
end
59+
60+
rmidxs = Int[]
61+
newscc = Vector{Int}[]; sizehint!(newscc, length(scc))
62+
for component′ in newscc
63+
component = copy(component′)
64+
for (idx, eq) in enumerate(component)
65+
if active_eqs[eq]
66+
component[idx] = eq_reidx[eq]
67+
else
68+
push!(rmidxs, idx)
69+
end
70+
end
71+
push!(newscc, component)
72+
deleteat!(component, rmidxs)
73+
empty!(rmidxs)
74+
end
75+
76+
### update graph
77+
var_reidx = Vector{Int}(undef, ndsts(graph))
78+
idx = 0
79+
for (i, active) in enumerate(active_vars)
80+
var_reidx[i] = active ? (idx += 1) : -1
81+
end
82+
83+
newgraph = BipartiteGraph(ns, nd, Val(false))
84+
85+
86+
### update equations
87+
odestats = []
88+
for idx in eachindex(fullvars); isdervar(s, idx) && continue
89+
push!(odestats, fullvars[idx])
90+
end
91+
newstates = setdiff(odestats, solvars)
92+
varidxmap = Dict(newstates .=> 1:length(newstates))
93+
neweqs = Vector{Equation}(undef, ns)
94+
newalgeqs = falses(ns)
95+
96+
dict = Dict(value.(solvars) .=> value.(rhss))
97+
98+
visited = falses(ndsts(graph))
99+
for ieq in Iterators.flatten(scc); active_eqs[ieq] || continue
100+
eq = eqs[ieq]
101+
ridx = eq_reidx[ieq]
102+
103+
fill!(visited, false)
104+
compact_graph!(newgraph, graph, visited, ieq, ridx, rvar2reqs, var_reidx, active_vars)
105+
106+
if isdiffeq(eq)
107+
neweqs[ridx] = eq.lhs ~ tearing_sub(eq.rhs, dict, simplify)
108+
else
109+
newalgeqs[ridx] = true
110+
if !(eq.lhs isa Number && eq.lhs != 0)
111+
eq = 0 ~ eq.rhs - eq.lhs
112+
end
113+
rhs = tearing_sub(eq.rhs, dict, simplify)
114+
if rhs isa Symbolic
115+
neweqs[ridx] = 0 ~ rhs
116+
else # a number
117+
if abs(rhs) > 100eps(float(rhs))
118+
@warn "The equation $eq is not consistent. It simplifed to 0 == $rhs."
119+
end
120+
neweqs[ridx] = 0 ~ fullvars[invview(var_eq_matching)[ieq]]
121+
end
122+
end
123+
end
124+
125+
### update partitions
126+
newpartitions = similar(partitions, 0)
127+
emptyintvec = Int[]
128+
for (ii, partition) in enumerate(partitions)
129+
@unpack e_residual, v_residual = partition
130+
isempty(v_residual) && continue
131+
new_e_residual = similar(e_residual)
132+
new_v_residual = similar(v_residual)
133+
for ii in eachindex(e_residual)
134+
new_e_residual[ii] = eq_reidx[ e_residual[ii]]
135+
new_v_residual[ii] = var_reidx[v_residual[ii]]
136+
end
137+
# `emptyintvec` is aliased to save memory
138+
# We need them for type stability
139+
newpart = SystemPartition(emptyintvec, emptyintvec, new_e_residual, new_v_residual)
140+
push!(newpartitions, newpart)
141+
end
142+
143+
obseqs = solvars .~ rhss
144+
145+
@set! s.graph = newgraph
146+
@set! s.scc = newscc
147+
@set! s.fullvars = fullvars[active_vars]
148+
@set! s.vartype = s.vartype[active_vars]
149+
@set! s.partitions = newpartitions
150+
@set! s.algeqs = newalgeqs
151+
152+
@set! sys.structure = s
153+
@set! sys.eqs = neweqs
154+
@set! sys.states = newstates
155+
@set! sys.observed = [observed(sys); obseqs]
156+
return sys
157+
end
158+
159+
# removes the solved equations and variables
160+
function compact_graph!(newgraph, graph, visited, eq, req, rvar2reqs, var_reidx, active_vars)
161+
for ivar in 𝑠neighbors(graph, eq)
162+
# Note that we need to check `ii` against the rhs states to make
163+
# sure we don't run in circles.
164+
visited[ivar] && continue
165+
visited[ivar] = true
166+
167+
if active_vars[ivar]
168+
add_edge!(newgraph, req, var_reidx[ivar])
169+
else
170+
# If a state is reduced, then we go to the rhs and collect
171+
# its states.
172+
for ieq in rvar2reqs[ivar]
173+
compact_graph!(newgraph, graph, visited, ieq, req, rvar2reqs, var_reidx, active_vars)
174+
end
175+
end
176+
end
177+
return nothing
178+
end
179+
180+
"""
181+
tearing(sys; simplify=false)
182+
183+
Tear the nonlinear equations in system. When `simplify=true`, we simplify the
184+
new residual residual equations after tearing. End users are encouraged to call [`structural_simplify`](@ref)
185+
instead, which calls this function internally.
186+
"""
187+
tearing(sys; simplify=false) = tearing_reassemble(tear_graph(algebraic_equations_scc(sys)); simplify=simplify)

src/structural_transformation/tearing.jl

Lines changed: 0 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -19,178 +19,6 @@ function tear_graph(sys)
1919
return sys
2020
end
2121

22-
function tearing_sub(expr, dict, s)
23-
expr = ModelingToolkit.fixpoint_sub(expr, dict)
24-
s ? simplify(expr) : expr
25-
end
26-
27-
function tearing_reassemble(sys; simplify=false)
28-
s = structure(sys)
29-
@unpack fullvars, partitions, var_eq_matching, graph, scc = s
30-
eqs = equations(sys)
31-
32-
### extract partition information
33-
rhss = []
34-
solvars = []
35-
ns, nd = nsrcs(graph), ndsts(graph)
36-
active_eqs = trues(ns)
37-
active_vars = trues(nd)
38-
rvar2req = Vector{Int}(undef, nd)
39-
for (ith_scc, partition) in enumerate(partitions)
40-
@unpack e_solved, v_solved, e_residual, v_residual = partition
41-
for ii in eachindex(e_solved)
42-
ieq = e_solved[ii]; ns -= 1
43-
iv = v_solved[ii]; nd -= 1
44-
rvar2req[iv] = ieq
45-
46-
active_eqs[ieq] = false
47-
active_vars[iv] = false
48-
49-
eq = eqs[ieq]
50-
var = fullvars[iv]
51-
rhs = value(solve_for(eq, var; simplify=simplify, check=false))
52-
# if we don't simplify the rhs and the `eq` is not solved properly
53-
(!simplify && occursin(rhs, var)) && (rhs = SymbolicUtils.polynormalize(rhs))
54-
# Since we know `eq` is linear wrt `var`, so the round off must be a
55-
# linear term. We can correct the round off error by a linear
56-
# correction.
57-
rhs -= expand_derivatives(Differential(var)(rhs))*var
58-
@assert !(var in vars(rhs)) """
59-
When solving
60-
$eq
61-
$var remainded in
62-
$rhs.
63-
"""
64-
push!(rhss, rhs)
65-
push!(solvars, var)
66-
end
67-
# DEBUG:
68-
#@show ith_scc solvars .~ rhss
69-
#Main._nlsys[] = eqs[e_solved], fullvars[v_solved]
70-
#ModelingToolkit.topsort_equations(solvars .~ rhss, fullvars)
71-
#empty!(solvars); empty!(rhss)
72-
end
73-
74-
### update SCC
75-
eq_reidx = Vector{Int}(undef, nsrcs(graph))
76-
idx = 0
77-
for (i, active) in enumerate(active_eqs)
78-
eq_reidx[i] = active ? (idx += 1) : -1
79-
end
80-
81-
rmidxs = Int[]
82-
newscc = Vector{Int}[]; sizehint!(newscc, length(scc))
83-
for component′ in newscc
84-
component = copy(component′)
85-
for (idx, eq) in enumerate(component)
86-
if active_eqs[eq]
87-
component[idx] = eq_reidx[eq]
88-
else
89-
push!(rmidxs, idx)
90-
end
91-
end
92-
push!(newscc, component)
93-
deleteat!(component, rmidxs)
94-
empty!(rmidxs)
95-
end
96-
97-
### update graph
98-
var_reidx = Vector{Int}(undef, ndsts(graph))
99-
idx = 0
100-
for (i, active) in enumerate(active_vars)
101-
var_reidx[i] = active ? (idx += 1) : -1
102-
end
103-
104-
newgraph = BipartiteGraph(ns, nd, Val(false))
105-
106-
function visit!(ii, gidx, basecase=true)
107-
ieq = basecase ? ii : rvar2req[ii]
108-
for ivar in 𝑠neighbors(graph, ieq)
109-
# Note that we need to check `ii` against the rhs states to make
110-
# sure we don't run in circles.
111-
(!basecase && ivar === ii) && continue
112-
if active_vars[ivar]
113-
add_edge!(newgraph, gidx, var_reidx[ivar])
114-
else
115-
# If a state is reduced, then we go to the rhs and collect
116-
# its states.
117-
visit!(ivar, gidx, false)
118-
end
119-
end
120-
return nothing
121-
end
122-
123-
### update equations
124-
odestats = []
125-
for idx in eachindex(fullvars); isdervar(s, idx) && continue
126-
push!(odestats, fullvars[idx])
127-
end
128-
newstates = setdiff(odestats, solvars)
129-
varidxmap = Dict(newstates .=> 1:length(newstates))
130-
neweqs = Vector{Equation}(undef, ns)
131-
newalgeqs = falses(ns)
132-
133-
dict = Dict(value.(solvars) .=> value.(rhss))
134-
135-
for ieq in Iterators.flatten(scc); active_eqs[ieq] || continue
136-
eq = eqs[ieq]
137-
ridx = eq_reidx[ieq]
138-
139-
visit!(ieq, ridx)
140-
141-
if isdiffeq(eq)
142-
neweqs[ridx] = eq.lhs ~ tearing_sub(eq.rhs, dict, simplify)
143-
else
144-
newalgeqs[ridx] = true
145-
if !(eq.lhs isa Number && eq.lhs != 0)
146-
eq = 0 ~ eq.rhs - eq.lhs
147-
end
148-
rhs = tearing_sub(eq.rhs, dict, simplify)
149-
if rhs isa Symbolic
150-
neweqs[ridx] = 0 ~ rhs
151-
else # a number
152-
if abs(rhs) > 100eps(float(rhs))
153-
@warn "The equation $eq is not consistent. It simplifed to 0 == $rhs."
154-
end
155-
neweqs[ridx] = 0 ~ fullvars[invview(var_eq_matching)[ieq]]
156-
end
157-
end
158-
end
159-
160-
### update partitions
161-
newpartitions = similar(partitions, 0)
162-
emptyintvec = Int[]
163-
for (ii, partition) in enumerate(partitions)
164-
@unpack e_residual, v_residual = partition
165-
isempty(v_residual) && continue
166-
new_e_residual = similar(e_residual)
167-
new_v_residual = similar(v_residual)
168-
for ii in eachindex(e_residual)
169-
new_e_residual[ii] = eq_reidx[ e_residual[ii]]
170-
new_v_residual[ii] = var_reidx[v_residual[ii]]
171-
end
172-
# `emptyintvec` is aliased to save memory
173-
# We need them for type stability
174-
newpart = SystemPartition(emptyintvec, emptyintvec, new_e_residual, new_v_residual)
175-
push!(newpartitions, newpart)
176-
end
177-
178-
obseqs = solvars .~ rhss
179-
180-
@set! s.graph = newgraph
181-
@set! s.scc = newscc
182-
@set! s.fullvars = fullvars[active_vars]
183-
@set! s.vartype = s.vartype[active_vars]
184-
@set! s.partitions = newpartitions
185-
@set! s.algeqs = newalgeqs
186-
187-
@set! sys.structure = s
188-
@set! sys.eqs = neweqs
189-
@set! sys.states = newstates
190-
@set! sys.observed = [observed(sys); obseqs]
191-
return sys
192-
end
193-
19422
"""
19523
algebraic_equations_scc(sys)
19624
@@ -212,12 +40,3 @@ function algebraic_equations_scc(sys)
21240
@set! sys.structure.scc = components
21341
return sys
21442
end
215-
216-
"""
217-
tearing(sys; simplify=false)
218-
219-
Tear the nonlinear equations in system. When `simplify=true`, we simplify the
220-
new residual residual equations after tearing. End users are encouraged to call [`structural_simplify`](@ref)
221-
instead, which calls this function internally.
222-
"""
223-
tearing(sys; simplify=false) = tearing_reassemble(tear_graph(algebraic_equations_scc(sys)); simplify=simplify)

0 commit comments

Comments
 (0)