Skip to content

Commit 7b8e0ef

Browse files
committed
Move file around
1 parent 15d1cd8 commit 7b8e0ef

File tree

3 files changed

+188
-189
lines changed

3 files changed

+188
-189
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ include("utils.jl")
4444
include("pantelides.jl")
4545
include("bipartite_tearing/modia_tearing.jl")
4646
include("tearing.jl")
47+
include("symbolics_tearing.jl")
4748
include("codegen.jl")
4849

4950
end # module
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
function tearing_sub(expr, dict, s)
2+
expr = ModelingToolkit.fixpoint_sub(expr, dict)
3+
s ? simplify(expr) : expr
4+
end
5+
6+
function tearing_reassemble(sys; simplify=false)
7+
s = structure(sys)
8+
@unpack fullvars, partitions, var_eq_matching, graph, scc = s
9+
eqs = equations(sys)
10+
11+
### extract partition information
12+
rhss = []
13+
solvars = []
14+
ns, nd = nsrcs(graph), ndsts(graph)
15+
active_eqs = trues(ns)
16+
active_vars = trues(nd)
17+
rvar2reqs = Vector{Vector{Int}}(undef, nd)
18+
for (ith_scc, partition) in enumerate(partitions)
19+
@unpack e_solved, v_solved, e_residual, v_residual = partition
20+
for ii in eachindex(e_solved)
21+
ieq = e_solved[ii]; ns -= 1
22+
iv = v_solved[ii]; nd -= 1
23+
rvar2reqs[iv] = e_solved
24+
25+
active_eqs[ieq] = false
26+
active_vars[iv] = false
27+
28+
eq = eqs[ieq]
29+
var = fullvars[iv]
30+
rhs = value(solve_for(eq, var; simplify=simplify, check=false))
31+
# if we don't simplify the rhs and the `eq` is not solved properly
32+
(!simplify && occursin(rhs, var)) && (rhs = SymbolicUtils.polynormalize(rhs))
33+
# Since we know `eq` is linear wrt `var`, so the round off must be a
34+
# linear term. We can correct the round off error by a linear
35+
# correction.
36+
rhs -= expand_derivatives(Differential(var)(rhs))*var
37+
@assert !(var in vars(rhs)) """
38+
When solving
39+
$eq
40+
$var remainded in
41+
$rhs.
42+
"""
43+
push!(rhss, rhs)
44+
push!(solvars, var)
45+
end
46+
# DEBUG:
47+
#@show ith_scc solvars .~ rhss
48+
#Main._nlsys[] = eqs[e_solved], fullvars[v_solved]
49+
#ModelingToolkit.topsort_equations(solvars .~ rhss, fullvars)
50+
#empty!(solvars); empty!(rhss)
51+
end
52+
53+
### update SCC
54+
eq_reidx = Vector{Int}(undef, nsrcs(graph))
55+
idx = 0
56+
for (i, active) in enumerate(active_eqs)
57+
eq_reidx[i] = active ? (idx += 1) : -1
58+
end
59+
60+
rmidxs = Int[]
61+
newscc = Vector{Int}[]; sizehint!(newscc, length(scc))
62+
for component′ in newscc
63+
component = copy(component′)
64+
for (idx, eq) in enumerate(component)
65+
if active_eqs[eq]
66+
component[idx] = eq_reidx[eq]
67+
else
68+
push!(rmidxs, idx)
69+
end
70+
end
71+
push!(newscc, component)
72+
deleteat!(component, rmidxs)
73+
empty!(rmidxs)
74+
end
75+
76+
### update graph
77+
var_reidx = Vector{Int}(undef, ndsts(graph))
78+
idx = 0
79+
for (i, active) in enumerate(active_vars)
80+
var_reidx[i] = active ? (idx += 1) : -1
81+
end
82+
83+
newgraph = BipartiteGraph(ns, nd, Val(false))
84+
85+
86+
### update equations
87+
odestats = []
88+
for idx in eachindex(fullvars); isdervar(s, idx) && continue
89+
push!(odestats, fullvars[idx])
90+
end
91+
newstates = setdiff(odestats, solvars)
92+
varidxmap = Dict(newstates .=> 1:length(newstates))
93+
neweqs = Vector{Equation}(undef, ns)
94+
newalgeqs = falses(ns)
95+
96+
dict = Dict(value.(solvars) .=> value.(rhss))
97+
98+
visited = falses(ndsts(graph))
99+
for ieq in Iterators.flatten(scc); active_eqs[ieq] || continue
100+
eq = eqs[ieq]
101+
ridx = eq_reidx[ieq]
102+
103+
fill!(visited, false)
104+
compact_graph!(newgraph, graph, visited, ieq, ridx, rvar2reqs, var_reidx, active_vars)
105+
106+
if isdiffeq(eq)
107+
neweqs[ridx] = eq.lhs ~ tearing_sub(eq.rhs, dict, simplify)
108+
else
109+
newalgeqs[ridx] = true
110+
if !(eq.lhs isa Number && eq.lhs != 0)
111+
eq = 0 ~ eq.rhs - eq.lhs
112+
end
113+
rhs = tearing_sub(eq.rhs, dict, simplify)
114+
if rhs isa Symbolic
115+
neweqs[ridx] = 0 ~ rhs
116+
else # a number
117+
if abs(rhs) > 100eps(float(rhs))
118+
@warn "The equation $eq is not consistent. It simplifed to 0 == $rhs."
119+
end
120+
neweqs[ridx] = 0 ~ fullvars[invview(var_eq_matching)[ieq]]
121+
end
122+
end
123+
end
124+
125+
### update partitions
126+
newpartitions = similar(partitions, 0)
127+
emptyintvec = Int[]
128+
for (ii, partition) in enumerate(partitions)
129+
@unpack e_residual, v_residual = partition
130+
isempty(v_residual) && continue
131+
new_e_residual = similar(e_residual)
132+
new_v_residual = similar(v_residual)
133+
for ii in eachindex(e_residual)
134+
new_e_residual[ii] = eq_reidx[ e_residual[ii]]
135+
new_v_residual[ii] = var_reidx[v_residual[ii]]
136+
end
137+
# `emptyintvec` is aliased to save memory
138+
# We need them for type stability
139+
newpart = SystemPartition(emptyintvec, emptyintvec, new_e_residual, new_v_residual)
140+
push!(newpartitions, newpart)
141+
end
142+
143+
obseqs = solvars .~ rhss
144+
145+
@set! s.graph = newgraph
146+
@set! s.scc = newscc
147+
@set! s.fullvars = fullvars[active_vars]
148+
@set! s.vartype = s.vartype[active_vars]
149+
@set! s.partitions = newpartitions
150+
@set! s.algeqs = newalgeqs
151+
152+
@set! sys.structure = s
153+
@set! sys.eqs = neweqs
154+
@set! sys.states = newstates
155+
@set! sys.observed = [observed(sys); obseqs]
156+
return sys
157+
end
158+
159+
# removes the solved equations and variables
160+
function compact_graph!(newgraph, graph, visited, eq, req, rvar2reqs, var_reidx, active_vars)
161+
for ivar in 𝑠neighbors(graph, eq)
162+
# Note that we need to check `ii` against the rhs states to make
163+
# sure we don't run in circles.
164+
visited[ivar] && continue
165+
visited[ivar] = true
166+
167+
if active_vars[ivar]
168+
add_edge!(newgraph, req, var_reidx[ivar])
169+
else
170+
# If a state is reduced, then we go to the rhs and collect
171+
# its states.
172+
for ieq in rvar2reqs[ivar]
173+
compact_graph!(newgraph, graph, visited, ieq, req, rvar2reqs, var_reidx, active_vars)
174+
end
175+
end
176+
end
177+
return nothing
178+
end
179+
180+
"""
181+
tearing(sys; simplify=false)
182+
183+
Tear the nonlinear equations in system. When `simplify=true`, we simplify the
184+
new residual residual equations after tearing. End users are encouraged to call [`structural_simplify`](@ref)
185+
instead, which calls this function internally.
186+
"""
187+
tearing(sys; simplify=false) = tearing_reassemble(tear_graph(algebraic_equations_scc(sys)); simplify=simplify)

src/structural_transformation/tearing.jl

Lines changed: 0 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -19,186 +19,6 @@ function tear_graph(sys)
1919
return sys
2020
end
2121

22-
function tearing_sub(expr, dict, s)
23-
expr = ModelingToolkit.fixpoint_sub(expr, dict)
24-
s ? simplify(expr) : expr
25-
end
26-
27-
function tearing_reassemble(sys; simplify=false)
28-
s = structure(sys)
29-
@unpack fullvars, partitions, var_eq_matching, graph, scc = s
30-
eqs = equations(sys)
31-
32-
### extract partition information
33-
rhss = []
34-
solvars = []
35-
ns, nd = nsrcs(graph), ndsts(graph)
36-
active_eqs = trues(ns)
37-
active_vars = trues(nd)
38-
rvar2reqs = Vector{Vector{Int}}(undef, nd)
39-
#reduction_graph = BipartiteGraph(nsrcs(graph), ndsts(graph), Val(false))
40-
for (ith_scc, partition) in enumerate(partitions)
41-
@unpack e_solved, v_solved, e_residual, v_residual = partition
42-
for ii in eachindex(e_solved)
43-
ieq = e_solved[ii]; ns -= 1
44-
iv = v_solved[ii]; nd -= 1
45-
rvar2reqs[iv] = e_solved
46-
47-
active_eqs[ieq] = false
48-
active_vars[iv] = false
49-
50-
eq = eqs[ieq]
51-
var = fullvars[iv]
52-
rhs = value(solve_for(eq, var; simplify=simplify, check=false))
53-
# if we don't simplify the rhs and the `eq` is not solved properly
54-
(!simplify && occursin(rhs, var)) && (rhs = SymbolicUtils.polynormalize(rhs))
55-
# Since we know `eq` is linear wrt `var`, so the round off must be a
56-
# linear term. We can correct the round off error by a linear
57-
# correction.
58-
rhs -= expand_derivatives(Differential(var)(rhs))*var
59-
@assert !(var in vars(rhs)) """
60-
When solving
61-
$eq
62-
$var remainded in
63-
$rhs.
64-
"""
65-
push!(rhss, rhs)
66-
push!(solvars, var)
67-
end
68-
# DEBUG:
69-
#@show ith_scc solvars .~ rhss
70-
#Main._nlsys[] = eqs[e_solved], fullvars[v_solved]
71-
#ModelingToolkit.topsort_equations(solvars .~ rhss, fullvars)
72-
#empty!(solvars); empty!(rhss)
73-
end
74-
75-
### update SCC
76-
eq_reidx = Vector{Int}(undef, nsrcs(graph))
77-
idx = 0
78-
for (i, active) in enumerate(active_eqs)
79-
eq_reidx[i] = active ? (idx += 1) : -1
80-
end
81-
82-
rmidxs = Int[]
83-
newscc = Vector{Int}[]; sizehint!(newscc, length(scc))
84-
for component′ in newscc
85-
component = copy(component′)
86-
for (idx, eq) in enumerate(component)
87-
if active_eqs[eq]
88-
component[idx] = eq_reidx[eq]
89-
else
90-
push!(rmidxs, idx)
91-
end
92-
end
93-
push!(newscc, component)
94-
deleteat!(component, rmidxs)
95-
empty!(rmidxs)
96-
end
97-
98-
### update graph
99-
var_reidx = Vector{Int}(undef, ndsts(graph))
100-
idx = 0
101-
for (i, active) in enumerate(active_vars)
102-
var_reidx[i] = active ? (idx += 1) : -1
103-
end
104-
105-
newgraph = BipartiteGraph(ns, nd, Val(false))
106-
107-
108-
### update equations
109-
odestats = []
110-
for idx in eachindex(fullvars); isdervar(s, idx) && continue
111-
push!(odestats, fullvars[idx])
112-
end
113-
newstates = setdiff(odestats, solvars)
114-
varidxmap = Dict(newstates .=> 1:length(newstates))
115-
neweqs = Vector{Equation}(undef, ns)
116-
newalgeqs = falses(ns)
117-
118-
dict = Dict(value.(solvars) .=> value.(rhss))
119-
120-
visited = falses(ndsts(graph))
121-
for ieq in Iterators.flatten(scc); active_eqs[ieq] || continue
122-
eq = eqs[ieq]
123-
ridx = eq_reidx[ieq]
124-
125-
fill!(visited, false)
126-
compact_graph!(newgraph, graph, visited, ieq, ridx, rvar2reqs, var_reidx, active_vars)
127-
128-
if isdiffeq(eq)
129-
neweqs[ridx] = eq.lhs ~ tearing_sub(eq.rhs, dict, simplify)
130-
else
131-
newalgeqs[ridx] = true
132-
if !(eq.lhs isa Number && eq.lhs != 0)
133-
eq = 0 ~ eq.rhs - eq.lhs
134-
end
135-
rhs = tearing_sub(eq.rhs, dict, simplify)
136-
if rhs isa Symbolic
137-
neweqs[ridx] = 0 ~ rhs
138-
else # a number
139-
if abs(rhs) > 100eps(float(rhs))
140-
@warn "The equation $eq is not consistent. It simplifed to 0 == $rhs."
141-
end
142-
neweqs[ridx] = 0 ~ fullvars[invview(var_eq_matching)[ieq]]
143-
end
144-
end
145-
end
146-
147-
### update partitions
148-
newpartitions = similar(partitions, 0)
149-
emptyintvec = Int[]
150-
for (ii, partition) in enumerate(partitions)
151-
@unpack e_residual, v_residual = partition
152-
isempty(v_residual) && continue
153-
new_e_residual = similar(e_residual)
154-
new_v_residual = similar(v_residual)
155-
for ii in eachindex(e_residual)
156-
new_e_residual[ii] = eq_reidx[ e_residual[ii]]
157-
new_v_residual[ii] = var_reidx[v_residual[ii]]
158-
end
159-
# `emptyintvec` is aliased to save memory
160-
# We need them for type stability
161-
newpart = SystemPartition(emptyintvec, emptyintvec, new_e_residual, new_v_residual)
162-
push!(newpartitions, newpart)
163-
end
164-
165-
obseqs = solvars .~ rhss
166-
167-
@set! s.graph = newgraph
168-
@set! s.scc = newscc
169-
@set! s.fullvars = fullvars[active_vars]
170-
@set! s.vartype = s.vartype[active_vars]
171-
@set! s.partitions = newpartitions
172-
@set! s.algeqs = newalgeqs
173-
174-
@set! sys.structure = s
175-
@set! sys.eqs = neweqs
176-
@set! sys.states = newstates
177-
@set! sys.observed = [observed(sys); obseqs]
178-
return sys
179-
end
180-
181-
# removes the solved equations and variables
182-
function compact_graph!(newgraph, graph, visited, eq, req, rvar2reqs, var_reidx, active_vars)
183-
for ivar in 𝑠neighbors(graph, eq)
184-
# Note that we need to check `ii` against the rhs states to make
185-
# sure we don't run in circles.
186-
visited[ivar] && continue
187-
visited[ivar] = true
188-
189-
if active_vars[ivar]
190-
add_edge!(newgraph, req, var_reidx[ivar])
191-
else
192-
# If a state is reduced, then we go to the rhs and collect
193-
# its states.
194-
for ieq in rvar2reqs[ivar]
195-
compact_graph!(newgraph, graph, visited, ieq, req, rvar2reqs, var_reidx, active_vars)
196-
end
197-
end
198-
end
199-
return nothing
200-
end
201-
20222
"""
20323
algebraic_equations_scc(sys)
20424
@@ -220,12 +40,3 @@ function algebraic_equations_scc(sys)
22040
@set! sys.structure.scc = components
22141
return sys
22242
end
223-
224-
"""
225-
tearing(sys; simplify=false)
226-
227-
Tear the nonlinear equations in system. When `simplify=true`, we simplify the
228-
new residual residual equations after tearing. End users are encouraged to call [`structural_simplify`](@ref)
229-
instead, which calls this function internally.
230-
"""
231-
tearing(sys; simplify=false) = tearing_reassemble(tear_graph(algebraic_equations_scc(sys)); simplify=simplify)

0 commit comments

Comments
 (0)