Skip to content

Commit c6e0f44

Browse files
committed
Replace open coded tarjan algorithm by Graphs.jl version
Besides avoiding redundancy and making used of the more optimized version in Graphs.jl, I think the extra abstraction also gives an insight into what exactly the induced orientation of the graph is that we're using to find strongly connected components. While we're at it also replace the hardcoded integer sentintel by a singleton type to align more with how we're doing this elsewhere in Julia.
1 parent c36c429 commit c6e0f44

File tree

6 files changed

+73
-80
lines changed

6 files changed

+73
-80
lines changed

src/bipartite_graph.jl

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module BipartiteGraphs
22

3-
export BipartiteEdge, BipartiteGraph
3+
export BipartiteEdge, BipartiteGraph, DiCMOBiGraph, Unassigned, unassigned
44

55
export 𝑠vertices, 𝑑vertices, has_𝑠vertex, has_𝑑vertex, 𝑠neighbors, 𝑑neighbors,
66
𝑠edges, 𝑑edges, nsrcs, ndsts, SRC, DST
@@ -11,6 +11,12 @@ using SparseArrays
1111
using Graphs
1212
using Setfield
1313

14+
### Matching
15+
struct Unassigned
16+
global unassigned
17+
const unassigned = Unassigned.instance
18+
end
19+
1420
###
1521
### Edges & Vertex
1622
###
@@ -269,4 +275,56 @@ function Graphs.incidence_matrix(g::BipartiteGraph, val=true)
269275
S = sparse(I, J, val, nsrcs(g), ndsts(g))
270276
end
271277

278+
279+
"""
280+
struct DiCMOBiGraph
281+
282+
This data structure implements a "directed, contracted, matching-oriented" view of an
283+
original (undirected) bipartite graph. In particular, it performs two largely
284+
orthogonal functions.
285+
286+
1. It pairs an undirected bipartite graph with a matching of destination vertex.
287+
288+
This matching is used to induce an orientation on the otherwise undirected graph:
289+
Matched edges pass from destination to source, all other edges pass in the opposite
290+
direction.
291+
292+
2. It exposes the graph view obtained by contracting the destination vertices into
293+
the source edges.
294+
295+
The result of this operation is an induced, directed graph on the source vertices.
296+
The resulting graph has a few desirable properties. In particular, this graph
297+
is acyclic if and only if the induced directed graph on the original bipartite
298+
graph is acyclic.
299+
"""
300+
struct DiCMOBiGraph{I, G<:BipartiteGraph{I}, M} <: Graphs.AbstractGraph{I}
301+
graph::G
302+
matching::M
303+
end
304+
Graphs.is_directed(::Type{<:DiCMOBiGraph}) = true
305+
Graphs.nv(g::DiCMOBiGraph) = nsrcs(g.graph)
306+
Graphs.vertices(g::DiCMOBiGraph) = 1:nsrcs(g.graph)
307+
308+
struct CMOOutNeighbors{V}
309+
g::DiCMOBiGraph
310+
v::V
311+
end
312+
Graphs.outneighbors(g::DiCMOBiGraph, v) = CMOOutNeighbors(g, v)
313+
Base.iterate(c::CMOOutNeighbors) = iterate(c, (c.g.graph.fadjlist[c.v],))
314+
function Base.iterate(c::CMOOutNeighbors, (l, state...))
315+
while true
316+
r = iterate(l, state...)
317+
r === nothing && return nothing
318+
# If this is a matched edge, skip it, it's reversed in the induced
319+
# directed graph. Otherwise, if there is no matching for this destination
320+
# edge, also skip it, since it got delted in the contraction.
321+
vdst = c.g.matching[r[1]]
322+
if vdst === c.v || vdst === unassigned
323+
state = (r[2],)
324+
continue
325+
end
326+
return vdst, (l, r[2])
327+
end
328+
end
329+
272330
end # module

src/structural_transformation/StructuralTransformations.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
module StructuralTransformations
22

3-
const UNVISITED = typemin(Int)
4-
const UNASSIGNED = typemin(Int)
5-
63
using Setfield: @set!, @set
74
using UnPack: @unpack
85

src/structural_transformation/pantelides.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function pantelides_reassemble(sys::ODESystem, eqassoc, assign)
6262
end
6363

6464
final_vars = unique(filter(x->!(operation(x) isa Differential), fullvars))
65-
final_eqs = map(identity, filter(x->value(x.lhs) !== nothing, out_eqs[sort(filter(x->x != UNASSIGNED, assign))]))
65+
final_eqs = map(identity, filter(x->value(x.lhs) !== nothing, out_eqs[sort(filter(x->x !== unassigned, assign))]))
6666

6767
@set! sys.eqs = final_eqs
6868
@set! sys.states = final_vars
@@ -84,7 +84,7 @@ function pantelides!(sys::ODESystem; maxiters = 8000)
8484
nvars = length(varassoc)
8585
vcolor = falses(nvars)
8686
ecolor = falses(neqs)
87-
assign = fill(UNASSIGNED, nvars)
87+
assign = Union{Unassigned, Int}[unassigned for _ = 1:nvars]
8888
eqassoc = fill(0, neqs)
8989
neqs′ = neqs
9090
D = Differential(iv)
@@ -112,7 +112,7 @@ function pantelides!(sys::ODESystem; maxiters = 8000)
112112
# the new variable is the derivative of `var`
113113
varassoc[var] = nvars
114114
push!(varassoc, 0)
115-
push!(assign, UNASSIGNED)
115+
push!(assign, unassigned)
116116
end
117117

118118
for eq in eachindex(ecolor); ecolor[eq] || continue

src/structural_transformation/utils.jl

Lines changed: 9 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function find_augmenting_path(g, eq, assign, varwhitelist, vcolor=falses(ndsts(g
1212

1313
# if a `var` is unassigned and the edge `eq <=> var` exists
1414
for var in 𝑠neighbors(g, eq)
15-
if (varwhitelist === nothing || varwhitelist[var]) && assign[var] == UNASSIGNED
15+
if (varwhitelist === nothing || varwhitelist[var]) && assign[var] === unassigned
1616
assign[var] = eq
1717
return true
1818
end
@@ -37,7 +37,7 @@ Find equation-variable bipartite matching. `s.graph` is a bipartite graph.
3737
"""
3838
matching(s::SystemStructure, varwhitelist=nothing, eqwhitelist=nothing) = matching(s.graph, varwhitelist, eqwhitelist)
3939
function matching(g::BipartiteGraph, varwhitelist=nothing, eqwhitelist=nothing)
40-
assign = fill(UNASSIGNED, ndsts(g))
40+
assign = Union{Unassigned, Int}[unassigned for _ = 1:ndsts(g)]
4141
for eq in 𝑠vertices(g)
4242
if eqwhitelist !== nothing
4343
eqwhitelist[eq] || continue
@@ -98,7 +98,7 @@ function check_consistency(sys::AbstractSystem)
9898
inv_assign = inverse_mapping(assign) # extra equations
9999
bad_idxs = findall(iszero, @view inv_assign[1:nsrcs(graph)])
100100
else
101-
bad_idxs = findall(isequal(UNASSIGNED), assign)
101+
bad_idxs = findall(isequal(unassigned), assign)
102102
end
103103
error_reporting(sys, bad_idxs, n_highest_vars, iseqs)
104104
end
@@ -110,7 +110,7 @@ function check_consistency(sys::AbstractSystem)
110110

111111
unassigned_var = []
112112
for (vj, eq) in enumerate(extended_assign)
113-
if eq === UNASSIGNED
113+
if eq === unassigned
114114
push!(unassigned_var, fullvars[vj])
115115
end
116116
end
@@ -149,72 +149,10 @@ gives the undirected bipartite graph a direction. When `assign === nothing`, we
149149
assume that the ``i``-th variable is assigned to the ``i``-th equation.
150150
"""
151151
function find_scc(g::BipartiteGraph, assign=nothing)
152-
id = 0
153-
stack = Int[]
154-
components = Vector{Int}[]
155-
n = nsrcs(g)
156-
onstack = falses(n)
157-
lowlink = zeros(Int, n)
158-
ids = fill(UNVISITED, n)
159-
160-
for eq in 𝑠vertices(g)
161-
if ids[eq] == UNVISITED
162-
id = strongly_connected!(stack, onstack, components, lowlink, ids, g, assign, eq, id)
163-
end
164-
end
165-
return components
166-
end
167-
168-
"""
169-
strongly_connected!(stack, onstack, components, lowlink, ids, g, assign, eq, id)
170-
171-
Use Tarjan's algorithm to find strongly connected components.
172-
"""
173-
function strongly_connected!(stack, onstack, components, lowlink, ids, g, assign, eq, id)
174-
id += 1
175-
lowlink[eq] = ids[eq] = id
176-
177-
# add `eq` to the stack
178-
push!(stack, eq)
179-
onstack[eq] = true
180-
181-
# for `adjeq` in the adjacency list of `eq`
182-
for var in 𝑠neighbors(g, eq)
183-
if assign === nothing
184-
adjeq = var
185-
else
186-
# assign[var] => the equation that's assigned to var
187-
adjeq = assign[var]
188-
# skip equations that are not assigned
189-
adjeq == UNASSIGNED && continue
190-
end
191-
192-
# if `adjeq` is not yet idsed
193-
if ids[adjeq] == UNVISITED # visit unvisited nodes
194-
id = strongly_connected!(stack, onstack, components, lowlink, ids, g, assign, adjeq, id)
195-
end
196-
# at the callback of the DFS
197-
if onstack[adjeq]
198-
lowlink[eq] = min(lowlink[eq], lowlink[adjeq])
199-
end
200-
end
201-
202-
# if we are at a start of a strongly connected component
203-
if lowlink[eq] == ids[eq]
204-
component = Int[]
205-
repeat = true
206-
# pop until we are at the start of the strongly connected component
207-
while repeat
208-
w = pop!(stack)
209-
onstack[w] = false
210-
lowlink[w] = ids[eq]
211-
# put `w` in current component
212-
push!(component, w)
213-
repeat = w != eq
214-
end
215-
push!(components, sort!(component))
216-
end
217-
return id
152+
cmog = DiCMOBiGraph(g, assign === nothing ? Base.OneTo(nsrcs(g)) : assign)
153+
sccs = Graphs.strongly_connected_components(cmog)
154+
foreach(sort!, sccs)
155+
return sccs
218156
end
219157

220158
function sorted_incidence_matrix(sys, val=true; only_algeqs=false, only_algvars=false)
@@ -295,7 +233,7 @@ end
295233
function inverse_mapping(assign)
296234
invassign = zeros(Int, length(assign))
297235
for (i, eq) in enumerate(assign)
298-
eq <= 0 && continue
236+
eq === unassigned && continue
299237
invassign[eq] = i
300238
end
301239
return invassign

src/systems/systemstructure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ Base.@kwdef struct SystemStructure
8080
algeqs::BitVector
8181
graph::BipartiteGraph{Int,Vector{Vector{Int}},Int,Nothing}
8282
solvable_graph::BipartiteGraph{Int,Vector{Vector{Int}},Int,Nothing}
83-
assign::Vector{Int}
83+
assign::Vector{Union{Int, Unassigned}}
8484
inv_assign::Vector{Int}
8585
scc::Vector{Vector{Int}}
8686
partitions::Vector{SystemPartition}

test/structural_transformation/index_reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pendulum = ODESystem(eqs, t, [x, y, w, z, T], [L, g], name=:pendulum)
3434
pendulum = initialize_system_structure(pendulum)
3535
sss = structure(pendulum)
3636
@unpack graph, fullvars, varassoc = sss
37-
@test StructuralTransformations.matching(sss, varassoc .== 0) == map(x -> x == 0 ? StructuralTransformations.UNASSIGNED : x, [1, 2, 3, 4, 0, 0, 0, 0, 0])
37+
@test StructuralTransformations.matching(sss, varassoc .== 0) == map(x -> x == 0 ? StructuralTransformations.unassigned : x, [1, 2, 3, 4, 0, 0, 0, 0, 0])
3838

3939
sys, assign, eqassoc = StructuralTransformations.pantelides!(pendulum)
4040
sss = structure(sys)

0 commit comments

Comments
 (0)