Skip to content

Commit 2f88927

Browse files
authored
Merge pull request #1350 from Keno/kf/assignstructs
Refactor (inv_)assign/{var/eq}assoc variables into proper structs
2 parents ef30028 + aa37e87 commit 2f88927

File tree

12 files changed

+224
-144
lines changed

12 files changed

+224
-144
lines changed

src/bipartite_graph.jl

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
module BipartiteGraphs
22

3-
export BipartiteEdge, BipartiteGraph, DiCMOBiGraph, Unassigned, unassigned
3+
export BipartiteEdge, BipartiteGraph, DiCMOBiGraph, Unassigned, unassigned,
4+
Matching
45

56
export 𝑠vertices, 𝑑vertices, has_𝑠vertex, has_𝑑vertex, 𝑠neighbors, 𝑑neighbors,
6-
𝑠edges, 𝑑edges, nsrcs, ndsts, SRC, DST, set_neighbors!
7+
𝑠edges, 𝑑edges, nsrcs, ndsts, SRC, DST, set_neighbors!, invview,
8+
complete
79

810
using DocStringExtensions
911
using UnPack
@@ -17,6 +19,47 @@ struct Unassigned
1719
const unassigned = Unassigned.instance
1820
end
1921

22+
struct Matching{V<:AbstractVector{<:Union{Unassigned, Int}}} <: AbstractVector{Union{Unassigned, Int}}
23+
match::V
24+
inv_match::Union{Nothing, V}
25+
end
26+
Matching(v::V) where {V<:AbstractVector{<:Union{Unassigned, Int}}} =
27+
Matching{V}(v, nothing)
28+
Matching(m::Int) = Matching(Union{Int, Unassigned}[unassigned for _ = 1:m], nothing)
29+
Matching(m::Matching) = m
30+
31+
Base.size(m::Matching) = Base.size(m.match)
32+
Base.getindex(m::Matching, i::Integer) = m.match[i]
33+
Base.iterate(m::Matching, state...) = iterate(m.match, state...)
34+
function Base.setindex!(m::Matching, v::Integer, i::Integer)
35+
if m.inv_match !== nothing
36+
m.inv_match[v] = i
37+
end
38+
return m.match[i] = v
39+
end
40+
41+
function Base.push!(m::Matching, v::Union{Integer, Unassigned})
42+
push!(m.match, v)
43+
if v !== unassigned && m.inv_match !== nothing
44+
m.inv_match[v] = length(m.match)
45+
end
46+
end
47+
48+
function complete(m::Matching)
49+
m.inv_match !== nothing && return m
50+
inv_match = Union{Unassigned, Int}[unassigned for _ = 1:length(m.match)]
51+
for (i, eq) in enumerate(m.match)
52+
eq === unassigned && continue
53+
inv_match[eq] = i
54+
end
55+
return Matching(collect(m.match), inv_match)
56+
end
57+
58+
function invview(m::Matching)
59+
m.inv_match === nothing && throw(ArgumentError("Backwards matching not defined. `complete` the matching first."))
60+
return Matching(m.inv_match, m.match)
61+
end
62+
2063
###
2164
### Edges & Vertex
2265
###
@@ -291,15 +334,15 @@ is acyclic if and only if the induced directed graph on the original bipartite
291334
graph is acyclic.
292335
293336
"""
294-
mutable struct DiCMOBiGraph{Transposed, I, G<:BipartiteGraph{I}, M} <: Graphs.AbstractGraph{I}
337+
mutable struct DiCMOBiGraph{Transposed, I, G<:BipartiteGraph{I}, M <: Matching} <: Graphs.AbstractGraph{I}
295338
graph::G
296339
ne::Union{Missing, Int}
297340
matching::M
298341
DiCMOBiGraph{Transposed}(g::G, ne::Union{Missing, Int}, m::M) where {Transposed, I, G<:BipartiteGraph{I}, M} =
299342
new{Transposed, I, G, M}(g, ne, m)
300343
end
301344
function DiCMOBiGraph{Transposed}(g::BipartiteGraph) where {Transposed}
302-
DiCMOBiGraph{Transposed}(g, 0, Union{Unassigned, Int}[unassigned for i = 1:ndsts(g)])
345+
DiCMOBiGraph{Transposed}(g, 0, Matching(ndsts(g)))
303346
end
304347
function DiCMOBiGraph{Transposed}(g::BipartiteGraph, m::M) where {Transposed, M}
305348
DiCMOBiGraph{Transposed}(g, missing, m)

src/compat/incremental_cycles.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ for a usage example.
1212
"""
1313
abstract type IncrementalCycleTracker{I} <: AbstractGraph{I} end
1414

15-
function (::Type{IncrementalCycleTracker})(s::AbstractGraph{I}; in_out_reverse=nothing) where {I}
15+
function (::Type{IncrementalCycleTracker})(s::AbstractGraph{I}; dir=:out) where {I}
1616
# TODO: Once we have more algorithms, the poly-algorithm decision goes here.
1717
# For now, we only have Algorithm N.
18-
return DenseGraphICT_BFGT_N{something(in_out_reverse, false)}(s)
18+
return DenseGraphICT_BFGT_N{something(dir == :in, false)}(s)
1919
end
2020

2121
# Cycle Detection Interface
@@ -116,7 +116,7 @@ function Base.setindex!(vec::TransactionalVector, val, idx)
116116
return nothing
117117
end
118118
Base.getindex(vec::TransactionalVector, idx) = vec.v[idx]
119-
Base.size(vec) = size(vec.v)
119+
Base.size(vec::TransactionalVector) = size(vec.v)
120120

121121
# Specific Algorithms
122122

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using ModelingToolkit: ODESystem, AbstractSystem,var_from_nested_derivative, Dif
2222
IncrementalCycleTracker, add_edge_checked!, topological_sort
2323

2424
using ModelingToolkit.BipartiteGraphs
25+
import .BipartiteGraphs: invview
2526
using Graphs
2627
using ModelingToolkit.SystemStructures
2728

src/structural_transformation/pantelides.jl

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,43 +2,39 @@
22
### Reassemble: structural information -> system
33
###
44

5-
function pantelides_reassemble(sys::ODESystem, eqassoc, assign)
5+
function pantelides_reassemble(sys::ODESystem, eq_to_diff, assign)
66
s = structure(sys)
7-
@unpack fullvars, varassoc = s
7+
@unpack fullvars, var_to_diff = s
88
# Step 1: write derivative equations
99
in_eqs = equations(sys)
10-
out_eqs = Vector{Any}(undef, length(eqassoc))
10+
out_eqs = Vector{Any}(undef, nv(eq_to_diff))
1111
fill!(out_eqs, nothing)
1212
out_eqs[1:length(in_eqs)] .= in_eqs
1313

14-
out_vars = Vector{Any}(undef, length(varassoc))
14+
out_vars = Vector{Any}(undef, nv(var_to_diff))
1515
fill!(out_vars, nothing)
1616
out_vars[1:length(fullvars)] .= fullvars
1717

1818
D = Differential(get_iv(sys))
1919

20-
for (i, v) in enumerate(varassoc)
21-
# fullvars[v] = D(fullvars[i])
22-
v == 0 && continue
23-
vi = out_vars[i]
20+
for (varidx, diff) in edges(var_to_diff)
21+
# fullvars[diff] = D(fullvars[var])
22+
vi = out_vars[varidx]
2423
@assert vi !== nothing "Something went wrong on reconstructing states from variable association list"
2524
# `fullvars[i]` needs to be not a `D(...)`, because we want the DAE to be
2625
# first-order.
2726
if isdifferential(vi)
28-
vi = out_vars[i] = diff2term(vi)
27+
vi = out_vars[varidx] = diff2term(vi)
2928
end
30-
out_vars[v] = D(vi)
29+
out_vars[diff] = D(vi)
3130
end
3231

3332
d_dict = Dict(zip(fullvars, 1:length(fullvars)))
3433
lhss = Set{Any}([x.lhs for x in in_eqs if isdiffeq(x)])
35-
for (i, e) in enumerate(eqassoc)
36-
if e === 0
37-
continue
38-
end
39-
# LHS variable is looked up from varassoc
40-
# the varassoc[i]-th variable is the differentiated version of var at i
41-
eq = out_eqs[i]
34+
for (eqidx, diff) in edges(eq_to_diff)
35+
# LHS variable is looked up from var_to_diff
36+
# the var_to_diff[i]-th variable is the differentiated version of var at i
37+
eq = out_eqs[eqidx]
4238
lhs = if !(eq.lhs isa Symbolic)
4339
0
4440
elseif isdiffeq(eq)
@@ -58,7 +54,7 @@ function pantelides_reassemble(sys::ODESystem, eqassoc, assign)
5854
rhs = ModelingToolkit.expand_derivatives(D(eq.rhs))
5955
substitution_dict = Dict(x.lhs => x.rhs for x in out_eqs if x !== nothing && x.lhs isa Symbolic)
6056
sub_rhs = substitute(rhs, substitution_dict)
61-
out_eqs[e] = lhs ~ sub_rhs
57+
out_eqs[diff] = lhs ~ sub_rhs
6258
end
6359

6460
final_vars = unique(filter(x->!(operation(x) isa Differential), fullvars))
@@ -78,16 +74,18 @@ Perform Pantelides algorithm.
7874
function pantelides!(sys::ODESystem; maxiters = 8000)
7975
s = structure(sys)
8076
# D(j) = assoc[j]
81-
@unpack graph, fullvars, varassoc = s
82-
iv = get_iv(sys)
77+
@unpack graph, var_to_diff = s
78+
return (sys, pantelides!(graph, var_to_diff)...)
79+
end
80+
81+
function pantelides!(graph, var_to_diff; maxiters = 8000)
8382
neqs = nsrcs(graph)
84-
nvars = length(varassoc)
83+
nvars = nv(var_to_diff)
8584
vcolor = falses(nvars)
8685
ecolor = falses(neqs)
87-
assign = Union{Unassigned, Int}[unassigned for _ = 1:nvars]
88-
eqassoc = fill(0, neqs)
86+
var_eq_matching = Matching(nvars)
87+
eq_to_diff = DiffGraph(neqs)
8988
neqs′ = neqs
90-
D = Differential(iv)
9189
for k in 1:neqs′
9290
eq′ = k
9391
pathfound = false
@@ -98,46 +96,46 @@ function pantelides!(sys::ODESystem; maxiters = 8000)
9896
#
9997
# the derivatives and algebraic variables are zeros in the variable
10098
# association list
101-
varwhitelist = varassoc .== 0
99+
varwhitelist = var_to_diff .== nothing
102100
resize!(vcolor, nvars)
103101
fill!(vcolor, false)
104102
resize!(ecolor, neqs)
105103
fill!(ecolor, false)
106-
pathfound = find_augmenting_path(graph, eq′, assign, varwhitelist, vcolor, ecolor)
104+
pathfound = find_augmenting_path(graph, eq′, var_eq_matching, varwhitelist, vcolor, ecolor)
107105
pathfound && break # terminating condition
108106
for var in eachindex(vcolor); vcolor[var] || continue
109107
# introduce a new variable
110108
nvars += 1
111109
add_vertex!(graph, DST)
112110
# the new variable is the derivative of `var`
113-
varassoc[var] = nvars
114-
push!(varassoc, 0)
115-
push!(assign, unassigned)
111+
112+
add_edge!(var_to_diff, var, add_vertex!(var_to_diff))
113+
push!(var_eq_matching, unassigned)
116114
end
117115

118116
for eq in eachindex(ecolor); ecolor[eq] || continue
119117
# introduce a new equation
120118
neqs += 1
121119
add_vertex!(graph, SRC)
122120
# the new equation is created by differentiating `eq`
123-
eqassoc[eq] = neqs
121+
eq_diff = add_vertex!(eq_to_diff)
122+
add_edge!(eq_to_diff, eq, eq_diff)
124123
for var in 𝑠neighbors(graph, eq)
125-
add_edge!(graph, neqs, var)
126-
add_edge!(graph, neqs, varassoc[var])
124+
add_edge!(graph, eq_diff, var)
125+
add_edge!(graph, eq_diff, var_to_diff[var])
127126
end
128-
push!(eqassoc, 0)
129127
end
130128

131129
for var in eachindex(vcolor); vcolor[var] || continue
132130
# the newly introduced `var`s and `eq`s have the inherits
133131
# assignment
134-
assign[varassoc[var]] = eqassoc[assign[var]]
132+
var_eq_matching[var_to_diff[var]] = eq_to_diff[var_eq_matching[var]]
135133
end
136-
eq′ = eqassoc[eq′]
134+
eq′ = eq_to_diff[eq′]
137135
end # for _ in 1:maxiters
138136
pathfound || error("maxiters=$maxiters reached! File a bug report if your system has a reasonable index (<100), and you are using the default `maxiters`. Try to increase the maxiters by `pantelides(sys::ODESystem; maxiters=1_000_000)` if your system has an incredibly high index and it is truly extremely large.")
139137
end # for k in 1:neqs′
140-
return sys, assign, eqassoc
138+
return var_eq_matching, eq_to_diff
141139
end
142140

143141
"""
@@ -150,6 +148,6 @@ instead, which calls this function internally.
150148
function dae_index_lowering(sys::ODESystem; kwargs...)
151149
s = get_structure(sys)
152150
(s isa SystemStructure) || (sys = initialize_system_structure(sys))
153-
sys, assign, eqassoc = pantelides!(sys; kwargs...)
154-
return pantelides_reassemble(sys, eqassoc, assign)
151+
sys, var_eq_matching, eq_to_diff = pantelides!(sys; kwargs...)
152+
return pantelides_reassemble(sys, eq_to_diff, var_eq_matching)
155153
end

src/structural_transformation/tearing.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ instead, which calls this function internally.
77
function tear_graph(sys)
88
find_solvables!(sys)
99
s = structure(sys)
10-
@unpack graph, solvable_graph, assign, inv_assign, scc = s
10+
@unpack graph, solvable_graph, var_eq_matching, scc = s
1111

1212
@set! sys.structure.partitions = map(scc) do c
1313
ieqs = filter(eq->isalgeq(s, eq), c)
14-
vars = inv_assign[ieqs]
14+
vars = Int[var for var in invview(var_eq_matching)[ieqs] if var !== unassigned]
1515

16-
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph); in_out_reverse=true)
16+
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph); dir=:in)
1717
SystemPartition(tearEquations!(ict, solvable_graph.fadjlist, ieqs, vars)...)
1818
end
1919
return sys
@@ -26,7 +26,7 @@ end
2626

2727
function tearing_reassemble(sys; simplify=false)
2828
s = structure(sys)
29-
@unpack fullvars, partitions, assign, inv_assign, graph, scc = s
29+
@unpack fullvars, partitions, var_eq_matching, graph, scc = s
3030
eqs = equations(sys)
3131

3232
### extract partition information
@@ -152,7 +152,7 @@ function tearing_reassemble(sys; simplify=false)
152152
if abs(rhs) > 100eps(float(rhs))
153153
@warn "The equation $eq is not consistent. It simplifed to 0 == $rhs."
154154
end
155-
neweqs[ridx] = 0 ~ fullvars[inv_assign[ieq]]
155+
neweqs[ridx] = 0 ~ fullvars[invview(var_eq_matching)[ieq]]
156156
end
157157
end
158158
end
@@ -205,14 +205,10 @@ function algebraic_equations_scc(sys)
205205

206206
# skip over differential equations
207207
algvars = isalgvar.(Ref(s), 1:ndsts(s.graph))
208-
eqs = equations(sys)
209-
assign = matching(s, algvars, s.algeqs)
210-
211-
components = find_scc(s.graph, assign)
212-
inv_assign = inverse_mapping(assign)
208+
var_eq_matching = complete(matching(s, algvars, s.algeqs))
209+
components = find_scc(s.graph, var_eq_matching)
213210

214-
@set! sys.structure.assign = assign
215-
@set! sys.structure.inv_assign = inv_assign
211+
@set! sys.structure.var_eq_matching = var_eq_matching
216212
@set! sys.structure.scc = components
217213
return sys
218214
end

0 commit comments

Comments
 (0)