Skip to content

Commit e3c8951

Browse files
committed
BipartiteGraph: Clean up Graphs.jl integration
Removes the `ALL` option for BipartiteGraph edges iteration. This option makes little sense. The fadjlist and badjlist represent the same edges just with two different index structures. It is up to the BipartiteGraph to keep them internally consistent, but there's no reason to allow iteration over both from outside - they should never be inconsistent. Also add a few more Graphs.jl integration to allow more interesting graph algorithms to be run on our graph as we explore.
1 parent 571eff7 commit e3c8951

File tree

4 files changed

+50
-33
lines changed

4 files changed

+50
-33
lines changed

src/bipartite_graph.jl

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module BipartiteGraphs
33
export BipartiteEdge, BipartiteGraph, DiCMOBiGraph, Unassigned, unassigned
44

55
export 𝑠vertices, 𝑑vertices, has_𝑠vertex, has_𝑑vertex, 𝑠neighbors, 𝑑neighbors,
6-
𝑠edges, 𝑑edges, nsrcs, ndsts, SRC, DST
6+
𝑠edges, 𝑑edges, nsrcs, ndsts, SRC, DST, set_neighbors!
77

88
using DocStringExtensions
99
using UnPack
@@ -20,7 +20,7 @@ end
2020
###
2121
### Edges & Vertex
2222
###
23-
@enum VertType SRC DST ALL
23+
@enum VertType SRC DST
2424

2525
struct BipartiteEdge{I<:Integer} <: Graphs.AbstractEdge{I}
2626
src::I
@@ -189,10 +189,17 @@ function Graphs.add_vertex!(g::BipartiteGraph{T}, type::VertType) where T
189189
return true # vertex successfully added
190190
end
191191

192+
function set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors::AbstractVector)
193+
old_nneighbors = length(g.fadjlist[i])
194+
new_nneighbors = length(new_neighbors)
195+
g.fadjlist[i] = new_neighbors
196+
g.ne += new_nneighbors - old_nneighbors
197+
end
198+
192199
###
193200
### Edges iteration
194201
###
195-
Graphs.edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(ALL))
202+
Graphs.edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(SRC))
196203
𝑠edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(SRC))
197204
𝑑edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(DST))
198205

@@ -202,8 +209,6 @@ struct BipartiteEdgeIter{T,G} <: Graphs.AbstractEdgeIter
202209
end
203210

204211
Base.length(it::BipartiteEdgeIter) = ne(it.g)
205-
Base.length(it::BipartiteEdgeIter{ALL}) = 2ne(it.g)
206-
207212
Base.eltype(it::BipartiteEdgeIter) = edgetype(it.g)
208213

209214
function Base.iterate(it::BipartiteEdgeIter{SRC,<:BipartiteGraph{T}}, state=(1, 1, SRC)) where T
@@ -247,21 +252,6 @@ function Base.iterate(it::BipartiteEdgeIter{DST,<:BipartiteGraph{T}}, state=(1,
247252
return nothing
248253
end
249254

250-
function Base.iterate(it::BipartiteEdgeIter{ALL,<:BipartiteGraph}, state=nothing)
251-
if state === nothing
252-
ss = iterate((@set it.type = Val(SRC)))
253-
elseif state[3] === SRC
254-
ss = iterate((@set it.type = Val(SRC)), state)
255-
elseif state[3] == DST
256-
ss = iterate((@set it.type = Val(DST)), state)
257-
end
258-
if ss === nothing && state[3] == SRC
259-
return iterate((@set it.type = Val(DST)))
260-
else
261-
return ss
262-
end
263-
end
264-
265255
###
266256
### Utils
267257
###
@@ -301,13 +291,20 @@ is acyclic if and only if the induced directed graph on the original bipartite
301291
graph is acyclic.
302292
303293
"""
304-
struct DiCMOBiGraph{Transposed, I, G<:BipartiteGraph{I}, M} <: Graphs.AbstractGraph{I}
294+
mutable struct DiCMOBiGraph{Transposed, I, G<:BipartiteGraph{I}, M} <: Graphs.AbstractGraph{I}
305295
graph::G
296+
ne::Union{Missing, Int}
306297
matching::M
307-
DiCMOBiGraph{Transposed}(g::G, m::M) where {Transposed, I, G<:BipartiteGraph{I}, M} =
308-
new{Transposed, I, G, M}(g, m)
298+
DiCMOBiGraph{Transposed}(g::G, ne::Union{Missing, Int}, m::M) where {Transposed, I, G<:BipartiteGraph{I}, M} =
299+
new{Transposed, I, G, M}(g, ne, m)
300+
end
301+
function DiCMOBiGraph{Transposed}(g::BipartiteGraph) where {Transposed}
302+
DiCMOBiGraph{Transposed}(g, 0, Union{Unassigned, Int}[unassigned for i = 1:ndsts(g)])
303+
end
304+
function DiCMOBiGraph{Transposed}(g::BipartiteGraph, m::M) where {Transposed, M}
305+
DiCMOBiGraph{Transposed}(g, missing, m)
309306
end
310-
DiCMOBiGraph{Transposed}(g::BipartiteGraph) where {Transposed} = DiCMOBiGraph{Transposed}(g, Union{Unassigned, Int}[unassigned for i = 1:ndsts(g)])
307+
311308
Graphs.is_directed(::Type{<:DiCMOBiGraph}) = true
312309
Graphs.nv(g::DiCMOBiGraph{Transposed}) where {Transposed} = Transposed ? ndsts(g.graph) : nsrcs(g.graph)
313310
Graphs.vertices(g::DiCMOBiGraph{Transposed}) where {Transposed} = Transposed ? 𝑑vertices(g.graph) : 𝑠vertices(g.graph)
@@ -318,6 +315,7 @@ struct CMONeighbors{Transposed, V}
318315
CMONeighbors{Transposed}(g::DiCMOBiGraph{Transposed}, v::V) where {Transposed, V} =
319316
new{Transposed, V}(g, v)
320317
end
318+
321319
Graphs.outneighbors(g::DiCMOBiGraph{false}, v) = CMONeighbors{false}(g, v)
322320
Base.iterate(c::CMONeighbors{false}) = iterate(c, (c.g.graph.fadjlist[c.v],))
323321
function Base.iterate(c::CMONeighbors{false}, (l, state...))
@@ -336,12 +334,13 @@ function Base.iterate(c::CMONeighbors{false}, (l, state...))
336334
end
337335
end
338336

337+
lift(f, x) = (x === unassigned || isnothing(x)) ? nothing : f(x)
338+
339+
_vsrc(c::CMONeighbors{true}) = c.g.matching[c.v]
340+
_neighbors(c::CMONeighbors{true}) = lift(vsrc->c.g.graph.fadjlist[vsrc], _vsrc(c))
341+
Base.length(c::CMONeighbors{true}) = something(lift(length, _neighbors(c)), 1) - 1
339342
Graphs.inneighbors(g::DiCMOBiGraph{true}, v) = CMONeighbors{true}(g, v)
340-
function Base.iterate(c::CMONeighbors{true})
341-
vsrc = c.g.matching[c.v]
342-
vsrc === unassigned && return nothing
343-
iterate(c, (c.g.graph.fadjlist[vsrc],))
344-
end
343+
Base.iterate(c::CMONeighbors{true}) = lift(ns->iterate(c, (ns,)), _neighbors(c))
345344
function Base.iterate(c::CMONeighbors{true}, (l, state...))
346345
while true
347346
r = iterate(l, state...)
@@ -354,4 +353,21 @@ function Base.iterate(c::CMONeighbors{true}, (l, state...))
354353
end
355354
end
356355

356+
_edges(g::DiCMOBiGraph{Transposed}) where Transposed = Transposed ?
357+
((w=>v for w in inneighbors(g, v)) for v in vertices(g)) :
358+
((v=>w for w in outneighbors(g, v)) for v in vertices(g))
359+
_count(c::CMONeighbors{true}) = length(c)
360+
_count(c::CMONeighbors{false}) = count(_->true, c)
361+
362+
Graphs.edges(g::DiCMOBiGraph) = (Graphs.SimpleEdge(p) for p in Iterators.flatten(_edges(g)))
363+
function Graphs.ne(g::DiCMOBiGraph)
364+
if g.ne === missing
365+
g.ne = mapreduce(x->_count(x.iter), +, _edges(g))
366+
end
367+
return g.ne
368+
end
369+
370+
Graphs.has_edge(g::DiCMOBiGraph{true}, a, b) = a in inneighbors(g, b)
371+
Graphs.has_edge(g::DiCMOBiGraph{false}, a, b) = b in outneighbors(g, a)
372+
357373
end # module

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}
3535
if G.matching[vj] === unassigned && (vj in vActive)
3636
r = add_edge_checked!(ict, Iterators.filter(!=(vj), 𝑠neighbors(G.graph, eq)), vj) do G
3737
G.matching[vj] = eq
38+
G.ne += length(𝑠neighbors(G.graph, eq)) - 1
3839
end
3940
r && break
4041
end

src/systems/alias_elimination.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ struct AliasGraph <: AbstractDict{Int, Pair{Int, Int}}
127127
end
128128
end
129129

130+
Base.length(ag::AliasGraph) = length(ag.eliminated)
131+
130132
function Base.getindex(ag::AliasGraph, i::Integer)
131133
r = ag.aliasto[i]
132134
r === nothing && throw(KeyError(i))
@@ -281,7 +283,7 @@ function alias_eliminate_graph!(graph, varassoc, mm_orig::SparseMatrixCLIL)
281283

282284
# Step 3: Reflect our update decitions back into the graph
283285
for (ei, e) in enumerate(mm.nzrows)
284-
graph.fadjlist[e] = mm.row_cols[ei]
286+
set_neighbors!(graph, e, mm.row_cols[ei])
285287
end
286288

287289
return ag, mm

test/structural_transformation/utils.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ sss = structure(sys)
2727
@test nv(solvable_graph) == 9 + 5
2828
@test varassoc == [0, 0, 0, 0, 1, 2, 3, 4, 0]
2929

30-
se = collect(StructuralTransformations.𝑠edges(graph))
30+
se = collect(StructuralTransformations.edges(graph))
3131
@test se == mapreduce(vcat, enumerate(graph.fadjlist)) do (s, d)
3232
StructuralTransformations.BipartiteEdge.(s, d)
3333
end
34-
@test_throws ArgumentError collect(StructuralTransformations.𝑑edges(graph))
35-
@test_throws ArgumentError collect(StructuralTransformations.edges(graph))

0 commit comments

Comments
 (0)