Skip to content

Commit e13068d

Browse files
committed
Fix state selection
1 parent 57dffe6 commit e13068d

File tree

8 files changed

+271
-98
lines changed

8 files changed

+271
-98
lines changed

src/bipartite_graph.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,25 @@ function Graphs.add_edge!(g::BipartiteGraph, edge::BipartiteEdge, md=NO_METADATA
306306
return true # edge successfully added
307307
end
308308

309+
Graphs.rem_edge!(g::BipartiteGraph, i::Integer, j::Integer) =
310+
Graphs.rem_edge!(g, BipartiteEdge(i, j))
311+
function Graphs.rem_edge!(g::BipartiteGraph, edge::BipartiteEdge)
312+
@unpack fadjlist, badjlist = g
313+
s, d = src(edge), dst(edge)
314+
(has_𝑠vertex(g, s) && has_𝑑vertex(g, d)) || error("edge ($edge) out of range.")
315+
@inbounds list = fadjlist[s]
316+
index = searchsortedfirst(list, d)
317+
@inbounds (index <= length(list) && list[index] == d) || error("graph does not have edge $edge")
318+
deleteat!(list, index)
319+
g.ne -= 1
320+
if badjlist isa AbstractVector
321+
@inbounds list = badjlist[d]
322+
index = searchsortedfirst(list, s)
323+
deleteat!(list, index)
324+
end
325+
return true # edge successfully deleted
326+
end
327+
309328
function Graphs.add_vertex!(g::BipartiteGraph{T}, type::VertType) where T
310329
if type === DST
311330
if g.badjlist isa AbstractVector

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@
1313
#
1414
################################################
1515

16+
function try_assign_eq!(ict::IncrementalCycleTracker, vj::Integer, eq::Integer)
17+
G = ict.graph
18+
add_edge_checked!(ict, Iterators.filter(!=(vj), 𝑠neighbors(G.graph, eq)), vj) do G
19+
G.matching[vj] = eq
20+
G.ne += length(𝑠neighbors(G.graph, eq)) - 1
21+
end
22+
end
23+
1624
"""
1725
(eSolved, vSolved, eResidue, vTear) = tearEquations!(td, Gsolvable, es, vs)
1826
@@ -33,10 +41,7 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}
3341
for eq in es # iterate only over equations that are not in eSolvedFixed
3442
for vj in Gsolvable[eq]
3543
if G.matching[vj] === unassigned && (vj in vActive)
36-
r = add_edge_checked!(ict, Iterators.filter(!=(vj), 𝑠neighbors(G.graph, eq)), vj) do G
37-
G.matching[vj] = eq
38-
G.ne += length(𝑠neighbors(G.graph, eq)) - 1
39-
end
44+
r = try_assign_eq!(ict, vj, eq)
4045
r && break
4146
end
4247
end
@@ -45,6 +50,15 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}
4550
return ict
4651
end
4752

53+
function tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, eqs, vars)
54+
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph); dir=:in)
55+
tearEquations!(ict, solvable_graph.fadjlist, eqs, vars)
56+
for var in vars
57+
var_eq_matching[var] = ict.graph.matching[var]
58+
end
59+
return nothing
60+
end
61+
4862
"""
4963
tear_graph_modia(sys) -> sys
5064
@@ -58,13 +72,10 @@ function tear_graph_modia(graph::BipartiteGraph, solvable_graph::BipartiteGraph;
5872
for vars in var_sccs
5973
filtered_vars = filter(varfilter, vars)
6074
ieqs = Int[var_eq_matching[v] for v in filtered_vars if var_eq_matching[v] !== unassigned]
61-
62-
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph); dir=:in)
63-
tearEquations!(ict, solvable_graph.fadjlist, ieqs, filtered_vars)
64-
6575
for var in vars
66-
var_eq_matching[var] = ict.graph.matching[var]
76+
var_eq_matching[var] = unassigned
6777
end
78+
tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, ieqs, filtered_vars)
6879
end
6980

7081
return var_eq_matching

src/structural_transformation/pantelides.jl

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,33 @@ function pantelides!(sys::ODESystem; maxiters = 8000)
7575
find_solvables!(sys)
7676
s = structure(sys)
7777
# D(j) = assoc[j]
78-
@unpack graph, var_to_diff, solvable_graph = s
79-
return (sys, pantelides!(graph, solvable_graph, var_to_diff)...)
78+
@unpack graph, var_to_diff = s
79+
# N.B.: var_derivative! and eq_derivative! are defined in symbolics_tearing.jl
80+
return (sys, pantelides!(PantelidesSetup(sys, graph, var_to_diff))...)
8081
end
8182

82-
function pantelides!(graph, solvable_graph, var_to_diff; maxiters = 8000)
83+
struct PantelidesSetup{T}
84+
system::T
85+
graph::BipartiteGraph
86+
var_to_diff::DiffGraph
87+
eq_to_diff::DiffGraph
88+
var_eq_matching::Matching
89+
end
90+
91+
function PantelidesSetup(sys::T, graph, var_to_diff) where {T}
8392
neqs = nsrcs(graph)
8493
nvars = nv(var_to_diff)
85-
vcolor = falses(nvars)
86-
ecolor = falses(neqs)
8794
var_eq_matching = Matching(nvars)
8895
eq_to_diff = DiffGraph(neqs)
96+
PantelidesSetup{T}(sys, graph, var_to_diff, eq_to_diff, var_eq_matching)
97+
end
98+
99+
function pantelides!(p::PantelidesSetup; maxiters = 8000)
100+
@unpack graph, var_to_diff, eq_to_diff, var_eq_matching = p
101+
neqs = nsrcs(graph)
102+
nvars = nv(var_to_diff)
103+
vcolor = falses(nvars)
104+
ecolor = falses(neqs)
89105
neqs′ = neqs
90106
for k in 1:neqs′
91107
eq′ = k
@@ -107,27 +123,22 @@ function pantelides!(graph, solvable_graph, var_to_diff; maxiters = 8000)
107123
for var in eachindex(vcolor); vcolor[var] || continue
108124
# introduce a new variable
109125
nvars += 1
110-
add_vertex!(graph, DST); add_vertex!(solvable_graph, DST)
126+
add_vertex!(graph, DST);
111127
# the new variable is the derivative of `var`
112128

113129
add_edge!(var_to_diff, var, add_vertex!(var_to_diff))
114130
push!(var_eq_matching, unassigned)
131+
var_derivative!(p, eq)
115132
end
116133

117134
for eq in eachindex(ecolor); ecolor[eq] || continue
118135
# introduce a new equation
119136
neqs += 1
120-
add_vertex!(graph, SRC); add_vertex!(solvable_graph, SRC)
137+
add_vertex!(graph, SRC);
121138
# the new equation is created by differentiating `eq`
122139
eq_diff = add_vertex!(eq_to_diff)
123140
add_edge!(eq_to_diff, eq, eq_diff)
124-
for var in 𝑠neighbors(graph, eq)
125-
add_edge!(graph, eq_diff, var)
126-
add_edge!(graph, eq_diff, var_to_diff[var])
127-
# If you have f(x) = 0, then the derivative is (∂f/∂x) ẋ = 0.
128-
# which is linear, thus solvable in ẋ.
129-
add_edge!(solvable_graph, eq_diff, var_to_diff[var])
130-
end
141+
eq_derivative!(p, eq)
131142
end
132143

133144
for var in eachindex(vcolor); vcolor[var] || continue

src/structural_transformation/partial_state_selection.jl

Lines changed: 108 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,108 @@ function partial_state_selection_graph!(sys::ODESystem)
22
s = get_structure(sys)
33
(s isa SystemStructure) || (sys = initialize_system_structure(sys))
44
s = structure(sys)
5-
find_solvables!(sys)
5+
find_solvables!(sys; allow_symbolic=true)
66
@set! s.graph = complete(s.graph)
77
@set! sys.structure = s
8-
(sys, partial_state_selection_graph!(s.graph, s.solvable_graph, s.var_to_diff)...)
8+
var_eq_matching, eq_to_diff = pantelides!(PantelidesSetup(sys, s.graph, s.var_to_diff))
9+
(sys, partial_state_selection_graph!(s.graph, s.solvable_graph, s.var_to_diff, var_eq_matching, eq_to_diff)...)
10+
end
11+
12+
function ascend_dg(xs, dg, level)
13+
while level > 0
14+
xs = Int[dg[x] for x in xs]
15+
level -= 1
16+
end
17+
return xs
18+
end
19+
20+
function ascend_dg_all(xs, dg, level, maxlevel)
21+
r = Int[]
22+
while true
23+
if level <= 0
24+
append!(r, xs)
25+
end
26+
maxlevel <= 0 && break
27+
xs = Int[dg[x] for x in xs if dg[x] !== nothing]
28+
level -= 1
29+
maxlevel -= 1
30+
end
31+
return r
32+
end
33+
34+
function pss_graph_modia!(graph, solvable_graph, var_eq_matching, var_to_diff, eq_to_diff, varlevel, inv_varlevel, inv_eqlevel)
35+
# var_eq_matching is a maximal matching on the top-differentiated variables.
36+
# Find Strongly connected components. Note that after pantelides, we expect
37+
# a balanced system, so a maximal matching should be possible.
38+
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)
39+
var_eq_matching = Matching{Union{Unassigned, SelectedState}}(var_eq_matching)
40+
for vars in var_sccs
41+
# TODO: We should have a way to not have the scc code look at unassigned vars.
42+
if length(vars) == 1 && varlevel[vars[1]] != 0
43+
continue
44+
end
45+
46+
# Now proceed level by level from lowest to highest and tear the graph.
47+
eqs = [var_eq_matching[var] for var in vars]
48+
maxlevel = level = maximum(map(x->inv_eqlevel[x], eqs))
49+
old_level_vars = ()
50+
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph, complete(Matching(ndsts(graph)))); dir=:in)
51+
while level >= 0
52+
to_tear_eqs_toplevel = filter(eq->inv_eqlevel[eq] >= level, eqs)
53+
to_tear_eqs = ascend_dg(to_tear_eqs_toplevel, invview(eq_to_diff), level)
54+
55+
to_tear_vars_toplevel = filter(var->inv_varlevel[var] >= level, vars)
56+
to_tear_vars = ascend_dg_all(to_tear_vars_toplevel, invview(var_to_diff), level, maxlevel)
57+
58+
if old_level_vars !== ()
59+
# Inherit constraints from previous level.
60+
# TODO: Is this actually a good idea or do we want full freedom
61+
# to tear differently on each level? Does it make a difference
62+
# whether we're using heuristic or optimal tearing?
63+
removed_eqs = Int[]
64+
removed_vars = Int[]
65+
for var in old_level_vars
66+
old_assign = ict.graph.matching[var]
67+
if !isa(old_assign, Int) || ict.graph.matching[var_to_diff[var]] !== unassigned
68+
continue
69+
end
70+
# Make sure the ict knows about this edge, so it doesn't accidentally introduce
71+
# a cycle.
72+
ok = try_assign_eq!(ict, var_to_diff[var], eq_to_diff[old_assign])
73+
@assert ok
74+
var_eq_matching[var_to_diff[var]] = eq_to_diff[old_assign]
75+
push!(removed_eqs, eq_to_diff[ict.graph.matching[var]])
76+
push!(removed_vars, var_to_diff[var])
77+
end
78+
to_tear_eqs = setdiff(to_tear_eqs, removed_eqs)
79+
to_tear_vars = setdiff(to_tear_vars, removed_vars)
80+
end
81+
filter!(var->ict.graph.matching[var] === unassigned, to_tear_vars)
82+
filter!(eq->invview(ict.graph.matching)[eq] === unassigned, to_tear_eqs)
83+
tearEquations!(ict, solvable_graph.fadjlist, to_tear_eqs, to_tear_vars)
84+
for var in to_tear_vars
85+
var_eq_matching[var] = ict.graph.matching[var]
86+
end
87+
old_level_vars = to_tear_vars
88+
level -= 1
89+
end
90+
for var in old_level_vars
91+
if varlevel[var] !== 0 && var_eq_matching[var] === unassigned
92+
var_eq_matching[var] = SelectedState()
93+
end
94+
end
95+
end
96+
return var_eq_matching
997
end
1098

1199
struct SelectedState; end
12-
function partial_state_selection_graph!(graph, solvable_graph, var_to_diff)
13-
var_eq_matching, eq_to_diff = pantelides!(graph, solvable_graph, var_to_diff)
100+
function partial_state_selection_graph!(graph, solvable_graph, var_to_diff, var_eq_matching, eq_to_diff)
14101
eq_to_diff = complete(eq_to_diff)
15102

16-
eqlevel = map(1:nsrcs(graph)) do eq
103+
inv_eqlevel = map(1:nsrcs(graph)) do eq
17104
level = 0
18-
while eq_to_diff[eq] !== nothing
19-
eq = eq_to_diff[eq]
105+
while invview(eq_to_diff)[eq] !== nothing
106+
eq = invview(eq_to_diff)[eq]
20107
level += 1
21108
end
22109
level
@@ -31,45 +118,25 @@ function partial_state_selection_graph!(graph, solvable_graph, var_to_diff)
31118
level
32119
end
33120

34-
all_selected_states = Int[]
35-
36-
level = 0
37-
level_vars = [var for var in 1:ndsts(graph) if varlevel[var] == 0 && invview(var_to_diff)[var] !== nothing]
121+
inv_varlevel = map(1:ndsts(graph)) do var
122+
level = 0
123+
while invview(var_to_diff)[var] !== nothing
124+
var = invview(var_to_diff)[var]
125+
level += 1
126+
end
127+
level
128+
end
38129

39-
# TODO: Is this actually useful or should we just compute another maximal matching?
130+
# TODO: Should pantelides just return this?
40131
for var in 1:ndsts(graph)
41-
if !(var in level_vars)
132+
if var_to_diff[var] !== nothing
42133
var_eq_matching[var] = unassigned
43134
end
44135
end
45136

46-
while level < maximum(eqlevel)
47-
var_eq_matching = tear_graph_modia(graph, solvable_graph;
48-
eqfilter = eq->eqlevel[eq] == level && invview(eq_to_diff)[eq] !== nothing,
49-
varfilter = var->(var in level_vars && !(var in all_selected_states)))
50-
for var in level_vars
51-
if var_eq_matching[var] === unassigned
52-
selected_state = invview(var_to_diff)[var]
53-
push!(all_selected_states, selected_state)
54-
#=
55-
# TODO: This is what the Matteson paper says, but it doesn't
56-
# quite seem to work.
57-
while selected_state !== nothing
58-
push!(all_selected_states, selected_state)
59-
selected_state = invview(var_to_diff)[selected_state]
60-
end
61-
=#
62-
end
63-
end
64-
level += 1
65-
level_vars = [var for var = 1:ndsts(graph) if varlevel[var] == level && invview(var_to_diff)[var] !== nothing]
66-
end
137+
var_eq_matching = pss_graph_modia!(graph, solvable_graph,
138+
complete(var_eq_matching), var_to_diff, eq_to_diff, varlevel, inv_varlevel,
139+
inv_eqlevel)
67140

68-
var_eq_matching = tear_graph_modia(graph, solvable_graph;
69-
varfilter = var->!(var in all_selected_states))
70-
var_eq_matching = Matching{Union{Unassigned, SelectedState}}(var_eq_matching)
71-
for var in all_selected_states
72-
var_eq_matching[var] = SelectedState()
73-
end
74-
return var_eq_matching, eq_to_diff
141+
var_eq_matching, eq_to_diff
75142
end

0 commit comments

Comments
 (0)