Skip to content

Commit 09d2ed3

Browse files
authored
Merge pull request #1355 from Keno/kf/cmographs
Refactor tearing
2 parents 6cfb1af + 384d869 commit 09d2ed3

File tree

11 files changed

+357
-383
lines changed

11 files changed

+357
-383
lines changed

src/bipartite_graph.jl

Lines changed: 128 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module BipartiteGraphs
22

33
export BipartiteEdge, BipartiteGraph, DiCMOBiGraph, Unassigned, unassigned,
4-
Matching
4+
Matching, ResidualCMOGraph, InducedCondensationGraph, maximal_matching,
5+
construct_augmenting_path!
56

67
export 𝑠vertices, 𝑑vertices, has_𝑠vertex, has_𝑑vertex, 𝑠neighbors, 𝑑neighbors,
78
𝑠edges, 𝑑edges, nsrcs, ndsts, SRC, DST, set_neighbors!, invview,
@@ -18,46 +19,69 @@ struct Unassigned
1819
global unassigned
1920
const unassigned = Unassigned.instance
2021
end
22+
# Behaves as a scalar
23+
Base.length(u::Unassigned) = 1
24+
Base.size(u::Unassigned) = ()
25+
Base.iterate(u::Unassigned) = (unassigned, nothing)
26+
Base.iterate(u::Unassigned, state) = nothing
2127

22-
struct Matching{V<:AbstractVector{<:Union{Unassigned, Int}}} <: AbstractVector{Union{Unassigned, Int}}
28+
Base.show(io::IO, ::Unassigned) =
29+
printstyled(io, "u"; color=:light_black)
30+
31+
struct Matching{U #=> :Unassigned =#, V<:AbstractVector} <: AbstractVector{Union{U, Int}}
2332
match::V
2433
inv_match::Union{Nothing, V}
2534
end
26-
Matching(v::V) where {V<:AbstractVector{<:Union{Unassigned, Int}}} =
27-
Matching{V}(v, nothing)
28-
Matching(m::Int) = Matching(Union{Int, Unassigned}[unassigned for _ = 1:m], nothing)
35+
# These constructors work around https://github.com/JuliaLang/julia/issues/41948
36+
function Matching{V}(m::Matching) where {V}
37+
eltype(m) === Union{V, Int} && return M
38+
VUT = typeof(similar(m.match, Union{V, Int}))
39+
Matching{V}(convert(VUT, m.match),
40+
m.inv_match === nothing ? nothing : convert(VUT, m.inv_match))
41+
end
2942
Matching(m::Matching) = m
43+
Matching{U}(v::V) where {U, V<:AbstractVector} = Matching{U, V}(v, nothing)
44+
Matching{U}(v::V, iv::Union{V, Nothing}) where {U, V<:AbstractVector} = Matching{U, V}(v, iv)
45+
Matching(v::V) where {U, V<:AbstractVector{Union{U, Int}}} =
46+
Matching{@isdefined(U) ? U : Unassigned, V}(v, nothing)
47+
Matching(m::Int) = Matching{Unassigned}(Union{Int, Unassigned}[unassigned for _ = 1:m], nothing)
3048

3149
Base.size(m::Matching) = Base.size(m.match)
3250
Base.getindex(m::Matching, i::Integer) = m.match[i]
3351
Base.iterate(m::Matching, state...) = iterate(m.match, state...)
34-
function Base.setindex!(m::Matching, v::Integer, i::Integer)
52+
Base.copy(m::Matching) = Matching(copy(m.match), m.inv_match === nothing ? nothing : copy(m.inv_match))
53+
function Base.setindex!(m::Matching{U}, v::Union{Integer, U}, i::Integer) where {U}
3554
if m.inv_match !== nothing
36-
m.inv_match[v] = i
55+
oldv = m.match[i]
56+
isa(oldv, Int) && (m.inv_match[oldv] = unassigned)
57+
isa(v, Int) && (m.inv_match[v] = i)
3758
end
3859
return m.match[i] = v
3960
end
4061

41-
function Base.push!(m::Matching, v::Union{Integer, Unassigned})
62+
function Base.push!(m::Matching{U}, v::Union{Integer, U}) where {U}
4263
push!(m.match, v)
4364
if v !== unassigned && m.inv_match !== nothing
4465
m.inv_match[v] = length(m.match)
4566
end
4667
end
4768

48-
function complete(m::Matching)
69+
function complete(m::Matching{U}) where {U}
4970
m.inv_match !== nothing && return m
50-
inv_match = Union{Unassigned, Int}[unassigned for _ = 1:length(m.match)]
71+
inv_match = Union{U, Int}[unassigned for _ = 1:length(m.match)]
5172
for (i, eq) in enumerate(m.match)
52-
eq === unassigned && continue
73+
isa(eq, Int) || continue
5374
inv_match[eq] = i
5475
end
55-
return Matching(collect(m.match), inv_match)
76+
return Matching{U}(collect(m.match), inv_match)
5677
end
5778

58-
function invview(m::Matching)
79+
@noinline require_complete(m::Matching) =
5980
m.inv_match === nothing && throw(ArgumentError("Backwards matching not defined. `complete` the matching first."))
60-
return Matching(m.inv_match, m.match)
81+
82+
function invview(m::Matching{U, V}) where {U, V}
83+
require_complete(m)
84+
return Matching{U, V}(m.inv_match, m.match)
6185
end
6286

6387
###
@@ -121,6 +145,26 @@ mutable struct BipartiteGraph{I<:Integer, M} <: Graphs.AbstractGraph{I}
121145
metadata::M
122146
end
123147
BipartiteGraph(ne::Integer, fadj::AbstractVector, badj::Union{AbstractVector,Integer}=maximum(maximum, fadj); metadata=nothing) = BipartiteGraph(ne, fadj, badj, metadata)
148+
BipartiteGraph(fadj::AbstractVector, badj::Union{AbstractVector,Integer}=maximum(maximum, fadj); metadata=nothing) =
149+
BipartiteGraph(mapreduce(length, +, fadj; init=0), fadj, badj, metadata)
150+
151+
@noinline require_complete(g::BipartiteGraph) = g.badjlist isa AbstractVector || throw(ArgumentError("The graph has no back edges. Use `complete`."))
152+
153+
function invview(g::BipartiteGraph)
154+
require_complete(g)
155+
BipartiteGraph(g.ne, g.badjlist, g.fadjlist)
156+
end
157+
158+
function complete(g::BipartiteGraph{I}) where {I}
159+
isa(g.badjlist, AbstractVector) && return g
160+
badjlist = Vector{I}[Vector{I}() for _ in 1:g.badjlist]
161+
for (s, l) in enumerate(g.fadjlist)
162+
for d in l
163+
push!(badjlist[d], s)
164+
end
165+
end
166+
BipartiteGraph(g.ne, g.fadjlist, badjlist)
167+
end
124168

125169
"""
126170
```julia
@@ -147,6 +191,7 @@ function BipartiteGraph(nsrcs::T, ndsts::T, backedge::Val{B}=Val(true); metadata
147191
BipartiteGraph(0, fadjlist, badjlist, metadata)
148192
end
149193

194+
Base.copy(bg::BipartiteGraph) = BipartiteGraph(bg.ne, copy(bg.fadjlist), copy(bg.badjlist), deepcopy(bg.metadata))
150195
Base.eltype(::Type{<:BipartiteGraph{I}}) where I = I
151196
function Base.empty!(g::BipartiteGraph)
152197
foreach(empty!, g.fadjlist)
@@ -159,8 +204,6 @@ function Base.empty!(g::BipartiteGraph)
159204
end
160205
Base.length(::BipartiteGraph) = error("length is not well defined! Use `ne` or `nv`.")
161206

162-
@noinline throw_no_back_edges() = throw(ArgumentError("The graph has no back edges."))
163-
164207
if isdefined(Graphs, :has_contiguous_vertices)
165208
Graphs.has_contiguous_vertices(::Type{<:BipartiteGraph}) = false
166209
end
@@ -172,7 +215,7 @@ has_𝑠vertex(g::BipartiteGraph, v::Integer) = v in 𝑠vertices(g)
172215
has_𝑑vertex(g::BipartiteGraph, v::Integer) = v in 𝑑vertices(g)
173216
𝑠neighbors(g::BipartiteGraph, i::Integer, with_metadata::Val{M}=Val(false)) where M = M ? zip(g.fadjlist[i], g.metadata[i]) : g.fadjlist[i]
174217
function 𝑑neighbors(g::BipartiteGraph, j::Integer, with_metadata::Val{M}=Val(false)) where M
175-
g.badjlist isa AbstractVector || throw_no_back_edges()
218+
require_complete(g)
176219
M ? zip(g.badjlist[j], (g.metadata[i][j] for i in g.badjlist[j])) : g.badjlist[j]
177220
end
178221
Graphs.ne(g::BipartiteGraph) = g.ne
@@ -185,7 +228,53 @@ ndsts(g::BipartiteGraph) = length(𝑑vertices(g))
185228
function Graphs.has_edge(g::BipartiteGraph, edge::BipartiteEdge)
186229
@unpack src, dst = edge
187230
(src in 𝑠vertices(g) && dst in 𝑑vertices(g)) || return false # edge out of bounds
188-
insorted(𝑠neighbors(src), dst)
231+
insorted(dst, 𝑠neighbors(g, src))
232+
end
233+
Base.in(edge::BipartiteEdge, g::BipartiteGraph) = Graphs.has_edge(g, edge)
234+
235+
### Maximal matching
236+
"""
237+
construct_augmenting_path!(m::Matching, g::BipartiteGraph, vsrc, dstfilter, vcolor=falses(ndsts(g)), ecolor=falses(nsrcs(g))) -> path_found::Bool
238+
239+
Try to construct an augmenting path in matching and if such a path is found,
240+
update the matching accordingly.
241+
"""
242+
function construct_augmenting_path!(matching::Matching, g::BipartiteGraph, vsrc, dstfilter, dcolor=falses(ndsts(g)), scolor=falses(nsrcs(g)))
243+
scolor[vsrc] = true
244+
245+
# if a `vdst` is unassigned and the edge `vsrc <=> vdst` exists
246+
for vdst in 𝑠neighbors(g, vsrc)
247+
if dstfilter(vdst) && matching[vdst] === unassigned
248+
matching[vdst] = vsrc
249+
return true
250+
end
251+
end
252+
253+
# for every `vsrc` such that edge `vsrc <=> vdst` exists and `vdst` is uncolored
254+
for vdst in 𝑠neighbors(g, vsrc)
255+
(dstfilter(vdst) && !dcolor[vdst]) || continue
256+
dcolor[vdst] = true
257+
if construct_augmenting_path!(matching, g, matching[vdst], dstfilter, dcolor, scolor)
258+
matching[vdst] = vsrc
259+
return true
260+
end
261+
end
262+
return false
263+
end
264+
265+
"""
266+
maximal_matching(g::BipartiteGraph, [srcfilter], [dstfilter])
267+
268+
For a bipartite graph `g`, construct a maximal matching of destination to source
269+
vertices, subject to the constraint that vertices for which `srcfilter` or `dstfilter`,
270+
return `false` may not be matched.
271+
"""
272+
function maximal_matching(g::BipartiteGraph, srcfilter=vsrc->true, dstfilter=vdst->true)
273+
matching = Matching(ndsts(g))
274+
foreach(Iterators.filter(srcfilter, 𝑠vertices(g))) do vsrc
275+
construct_augmenting_path!(matching, g, vsrc, dstfilter)
276+
end
277+
return matching
189278
end
190279

191280
###
@@ -333,6 +422,14 @@ The resulting graph has a few desirable properties. In particular, this graph
333422
is acyclic if and only if the induced directed graph on the original bipartite
334423
graph is acyclic.
335424
425+
# Hypergraph interpretation
426+
427+
Consider the bipartite graph `B` as the incidence graph of some hypergraph `H`.
428+
Note that a maching `M` on `B` in the above sense is equivalent to determining
429+
an (1,n)-orientation on the hypergraph (i.e. each directed hyperedge has exactly
430+
one head, but any arbitrary number of tails). In this setting, this is simply
431+
the graph formed by expanding each directed hyperedge into `n` ordinary edges
432+
between the same vertices.
336433
"""
337434
mutable struct DiCMOBiGraph{Transposed, I, G<:BipartiteGraph{I}, M <: Matching} <: Graphs.AbstractGraph{I}
338435
graph::G
@@ -348,6 +445,9 @@ function DiCMOBiGraph{Transposed}(g::BipartiteGraph, m::M) where {Transposed, M}
348445
DiCMOBiGraph{Transposed}(g, missing, m)
349446
end
350447

448+
invview(g::DiCMOBiGraph{Transposed}) where {Transposed} =
449+
DiCMOBiGraph{!Transposed}(invview(g.graph), g.ne, invview(g.matching))
450+
351451
Graphs.is_directed(::Type{<:DiCMOBiGraph}) = true
352452
Graphs.nv(g::DiCMOBiGraph{Transposed}) where {Transposed} = Transposed ? ndsts(g.graph) : nsrcs(g.graph)
353453
Graphs.vertices(g::DiCMOBiGraph{Transposed}) where {Transposed} = Transposed ? 𝑑vertices(g.graph) : 𝑠vertices(g.graph)
@@ -360,6 +460,7 @@ struct CMONeighbors{Transposed, V}
360460
end
361461

362462
Graphs.outneighbors(g::DiCMOBiGraph{false}, v) = CMONeighbors{false}(g, v)
463+
Graphs.inneighbors(g::DiCMOBiGraph{false}, v) = inneighbors(invview(g), v)
363464
Base.iterate(c::CMONeighbors{false}) = iterate(c, (c.g.graph.fadjlist[c.v],))
364465
function Base.iterate(c::CMONeighbors{false}, (l, state...))
365466
while true
@@ -376,14 +477,17 @@ function Base.iterate(c::CMONeighbors{false}, (l, state...))
376477
return vsrc, (l, r[2])
377478
end
378479
end
480+
Base.length(c::CMONeighbors{false}) = count(_->true, c)
379481

380-
lift(f, x) = (x === unassigned || isnothing(x)) ? nothing : f(x)
482+
liftint(f, x) = (!isa(x, Int)) ? nothing : f(x)
483+
liftnothing(f, x) = x === nothing ? nothing : f(x)
381484

382485
_vsrc(c::CMONeighbors{true}) = c.g.matching[c.v]
383-
_neighbors(c::CMONeighbors{true}) = lift(vsrc->c.g.graph.fadjlist[vsrc], _vsrc(c))
384-
Base.length(c::CMONeighbors{true}) = something(lift(length, _neighbors(c)), 1) - 1
486+
_neighbors(c::CMONeighbors{true}) = liftint(vsrc->c.g.graph.fadjlist[vsrc], _vsrc(c))
487+
Base.length(c::CMONeighbors{true}) = something(liftnothing(length, _neighbors(c)), 1) - 1
385488
Graphs.inneighbors(g::DiCMOBiGraph{true}, v) = CMONeighbors{true}(g, v)
386-
Base.iterate(c::CMONeighbors{true}) = lift(ns->iterate(c, (ns,)), _neighbors(c))
489+
Graphs.outneighbors(g::DiCMOBiGraph{true}, v) = outneighbors(invview(g), v)
490+
Base.iterate(c::CMONeighbors{true}) = liftnothing(ns->iterate(c, (ns,)), _neighbors(c))
387491
function Base.iterate(c::CMONeighbors{true}, (l, state...))
388492
while true
389493
r = iterate(l, state...)
@@ -396,16 +500,15 @@ function Base.iterate(c::CMONeighbors{true}, (l, state...))
396500
end
397501
end
398502

503+
399504
_edges(g::DiCMOBiGraph{Transposed}) where Transposed = Transposed ?
400505
((w=>v for w in inneighbors(g, v)) for v in vertices(g)) :
401506
((v=>w for w in outneighbors(g, v)) for v in vertices(g))
402-
_count(c::CMONeighbors{true}) = length(c)
403-
_count(c::CMONeighbors{false}) = count(_->true, c)
404507

405508
Graphs.edges(g::DiCMOBiGraph) = (Graphs.SimpleEdge(p) for p in Iterators.flatten(_edges(g)))
406509
function Graphs.ne(g::DiCMOBiGraph)
407510
if g.ne === missing
408-
g.ne = mapreduce(x->_count(x.iter), +, _edges(g))
511+
g.ne = mapreduce(x->length(x.iter), +, _edges(g))
409512
end
410513
return g.ne
411514
end

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,30 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}
4242
end
4343
end
4444

45-
vSolved = filter(v->G.matching[v] !== unassigned, topological_sort(ict))
46-
inv_matching = Union{Missing, Int}[missing for _ = 1:nv(G)]
47-
for (v, eq) in pairs(G.matching)
48-
eq === unassigned && continue
49-
inv_matching[v] = eq
45+
return ict
46+
end
47+
48+
"""
49+
tear_graph_modia(sys) -> sys
50+
51+
Tear the bipartite graph in a system. End users are encouraged to call [`structural_simplify`](@ref)
52+
instead, which calls this function internally.
53+
"""
54+
function tear_graph_modia(graph::BipartiteGraph, solvable_graph::BipartiteGraph; varfilter=v->true, eqfilter=eq->true)
55+
var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter))
56+
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)
57+
58+
for vars in var_sccs
59+
filtered_vars = filter(varfilter, vars)
60+
ieqs = Int[var_eq_matching[v] for v in filtered_vars if var_eq_matching[v] !== unassigned]
61+
62+
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph); dir=:in)
63+
tearEquations!(ict, solvable_graph.fadjlist, ieqs, filtered_vars)
64+
65+
for var in vars
66+
var_eq_matching[var] = ict.graph.matching[var]
67+
end
5068
end
51-
eSolved = getindex.(Ref(inv_matching), vSolved)
52-
vTear = setdiff(vs, vSolved)
53-
eResidue = setdiff(es, eSolved)
54-
return (eSolved, vSolved, eResidue, vTear)
69+
70+
return var_eq_matching
5571
end

0 commit comments

Comments
 (0)