Skip to content

Commit 0c6e5c9

Browse files
authored
Merge pull request #1347 from Keno/kf/graphscleanup1
BipartiteGraph: Clean up Graphs.jl integration
2 parents 571eff7 + e3c8951 commit 0c6e5c9

File tree

4 files changed

+50
-33
lines changed

4 files changed

+50
-33
lines changed

src/bipartite_graph.jl

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module BipartiteGraphs
33
export BipartiteEdge, BipartiteGraph, DiCMOBiGraph, Unassigned, unassigned
44

55
export 𝑠vertices, 𝑑vertices, has_𝑠vertex, has_𝑑vertex, 𝑠neighbors, 𝑑neighbors,
6-
𝑠edges, 𝑑edges, nsrcs, ndsts, SRC, DST
6+
𝑠edges, 𝑑edges, nsrcs, ndsts, SRC, DST, set_neighbors!
77

88
using DocStringExtensions
99
using UnPack
@@ -20,7 +20,7 @@ end
2020
###
2121
### Edges & Vertex
2222
###
23-
@enum VertType SRC DST ALL
23+
@enum VertType SRC DST
2424

2525
struct BipartiteEdge{I<:Integer} <: Graphs.AbstractEdge{I}
2626
src::I
@@ -189,10 +189,17 @@ function Graphs.add_vertex!(g::BipartiteGraph{T}, type::VertType) where T
189189
return true # vertex successfully added
190190
end
191191

192+
function set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors::AbstractVector)
193+
old_nneighbors = length(g.fadjlist[i])
194+
new_nneighbors = length(new_neighbors)
195+
g.fadjlist[i] = new_neighbors
196+
g.ne += new_nneighbors - old_nneighbors
197+
end
198+
192199
###
193200
### Edges iteration
194201
###
195-
Graphs.edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(ALL))
202+
Graphs.edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(SRC))
196203
𝑠edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(SRC))
197204
𝑑edges(g::BipartiteGraph) = BipartiteEdgeIter(g, Val(DST))
198205

@@ -202,8 +209,6 @@ struct BipartiteEdgeIter{T,G} <: Graphs.AbstractEdgeIter
202209
end
203210

204211
Base.length(it::BipartiteEdgeIter) = ne(it.g)
205-
Base.length(it::BipartiteEdgeIter{ALL}) = 2ne(it.g)
206-
207212
Base.eltype(it::BipartiteEdgeIter) = edgetype(it.g)
208213

209214
function Base.iterate(it::BipartiteEdgeIter{SRC,<:BipartiteGraph{T}}, state=(1, 1, SRC)) where T
@@ -247,21 +252,6 @@ function Base.iterate(it::BipartiteEdgeIter{DST,<:BipartiteGraph{T}}, state=(1,
247252
return nothing
248253
end
249254

250-
function Base.iterate(it::BipartiteEdgeIter{ALL,<:BipartiteGraph}, state=nothing)
251-
if state === nothing
252-
ss = iterate((@set it.type = Val(SRC)))
253-
elseif state[3] === SRC
254-
ss = iterate((@set it.type = Val(SRC)), state)
255-
elseif state[3] == DST
256-
ss = iterate((@set it.type = Val(DST)), state)
257-
end
258-
if ss === nothing && state[3] == SRC
259-
return iterate((@set it.type = Val(DST)))
260-
else
261-
return ss
262-
end
263-
end
264-
265255
###
266256
### Utils
267257
###
@@ -301,13 +291,20 @@ is acyclic if and only if the induced directed graph on the original bipartite
301291
graph is acyclic.
302292
303293
"""
304-
struct DiCMOBiGraph{Transposed, I, G<:BipartiteGraph{I}, M} <: Graphs.AbstractGraph{I}
294+
mutable struct DiCMOBiGraph{Transposed, I, G<:BipartiteGraph{I}, M} <: Graphs.AbstractGraph{I}
305295
graph::G
296+
ne::Union{Missing, Int}
306297
matching::M
307-
DiCMOBiGraph{Transposed}(g::G, m::M) where {Transposed, I, G<:BipartiteGraph{I}, M} =
308-
new{Transposed, I, G, M}(g, m)
298+
DiCMOBiGraph{Transposed}(g::G, ne::Union{Missing, Int}, m::M) where {Transposed, I, G<:BipartiteGraph{I}, M} =
299+
new{Transposed, I, G, M}(g, ne, m)
300+
end
301+
function DiCMOBiGraph{Transposed}(g::BipartiteGraph) where {Transposed}
302+
DiCMOBiGraph{Transposed}(g, 0, Union{Unassigned, Int}[unassigned for i = 1:ndsts(g)])
303+
end
304+
function DiCMOBiGraph{Transposed}(g::BipartiteGraph, m::M) where {Transposed, M}
305+
DiCMOBiGraph{Transposed}(g, missing, m)
309306
end
310-
DiCMOBiGraph{Transposed}(g::BipartiteGraph) where {Transposed} = DiCMOBiGraph{Transposed}(g, Union{Unassigned, Int}[unassigned for i = 1:ndsts(g)])
307+
311308
Graphs.is_directed(::Type{<:DiCMOBiGraph}) = true
312309
Graphs.nv(g::DiCMOBiGraph{Transposed}) where {Transposed} = Transposed ? ndsts(g.graph) : nsrcs(g.graph)
313310
Graphs.vertices(g::DiCMOBiGraph{Transposed}) where {Transposed} = Transposed ? 𝑑vertices(g.graph) : 𝑠vertices(g.graph)
@@ -318,6 +315,7 @@ struct CMONeighbors{Transposed, V}
318315
CMONeighbors{Transposed}(g::DiCMOBiGraph{Transposed}, v::V) where {Transposed, V} =
319316
new{Transposed, V}(g, v)
320317
end
318+
321319
Graphs.outneighbors(g::DiCMOBiGraph{false}, v) = CMONeighbors{false}(g, v)
322320
Base.iterate(c::CMONeighbors{false}) = iterate(c, (c.g.graph.fadjlist[c.v],))
323321
function Base.iterate(c::CMONeighbors{false}, (l, state...))
@@ -336,12 +334,13 @@ function Base.iterate(c::CMONeighbors{false}, (l, state...))
336334
end
337335
end
338336

337+
lift(f, x) = (x === unassigned || isnothing(x)) ? nothing : f(x)
338+
339+
_vsrc(c::CMONeighbors{true}) = c.g.matching[c.v]
340+
_neighbors(c::CMONeighbors{true}) = lift(vsrc->c.g.graph.fadjlist[vsrc], _vsrc(c))
341+
Base.length(c::CMONeighbors{true}) = something(lift(length, _neighbors(c)), 1) - 1
339342
Graphs.inneighbors(g::DiCMOBiGraph{true}, v) = CMONeighbors{true}(g, v)
340-
function Base.iterate(c::CMONeighbors{true})
341-
vsrc = c.g.matching[c.v]
342-
vsrc === unassigned && return nothing
343-
iterate(c, (c.g.graph.fadjlist[vsrc],))
344-
end
343+
Base.iterate(c::CMONeighbors{true}) = lift(ns->iterate(c, (ns,)), _neighbors(c))
345344
function Base.iterate(c::CMONeighbors{true}, (l, state...))
346345
while true
347346
r = iterate(l, state...)
@@ -354,4 +353,21 @@ function Base.iterate(c::CMONeighbors{true}, (l, state...))
354353
end
355354
end
356355

356+
_edges(g::DiCMOBiGraph{Transposed}) where Transposed = Transposed ?
357+
((w=>v for w in inneighbors(g, v)) for v in vertices(g)) :
358+
((v=>w for w in outneighbors(g, v)) for v in vertices(g))
359+
_count(c::CMONeighbors{true}) = length(c)
360+
_count(c::CMONeighbors{false}) = count(_->true, c)
361+
362+
Graphs.edges(g::DiCMOBiGraph) = (Graphs.SimpleEdge(p) for p in Iterators.flatten(_edges(g)))
363+
function Graphs.ne(g::DiCMOBiGraph)
364+
if g.ne === missing
365+
g.ne = mapreduce(x->_count(x.iter), +, _edges(g))
366+
end
367+
return g.ne
368+
end
369+
370+
Graphs.has_edge(g::DiCMOBiGraph{true}, a, b) = a in inneighbors(g, b)
371+
Graphs.has_edge(g::DiCMOBiGraph{false}, a, b) = b in outneighbors(g, a)
372+
357373
end # module

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}
3535
if G.matching[vj] === unassigned && (vj in vActive)
3636
r = add_edge_checked!(ict, Iterators.filter(!=(vj), 𝑠neighbors(G.graph, eq)), vj) do G
3737
G.matching[vj] = eq
38+
G.ne += length(𝑠neighbors(G.graph, eq)) - 1
3839
end
3940
r && break
4041
end

src/systems/alias_elimination.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ struct AliasGraph <: AbstractDict{Int, Pair{Int, Int}}
127127
end
128128
end
129129

130+
Base.length(ag::AliasGraph) = length(ag.eliminated)
131+
130132
function Base.getindex(ag::AliasGraph, i::Integer)
131133
r = ag.aliasto[i]
132134
r === nothing && throw(KeyError(i))
@@ -281,7 +283,7 @@ function alias_eliminate_graph!(graph, varassoc, mm_orig::SparseMatrixCLIL)
281283

282284
# Step 3: Reflect our update decitions back into the graph
283285
for (ei, e) in enumerate(mm.nzrows)
284-
graph.fadjlist[e] = mm.row_cols[ei]
286+
set_neighbors!(graph, e, mm.row_cols[ei])
285287
end
286288

287289
return ag, mm

test/structural_transformation/utils.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ sss = structure(sys)
2727
@test nv(solvable_graph) == 9 + 5
2828
@test varassoc == [0, 0, 0, 0, 1, 2, 3, 4, 0]
2929

30-
se = collect(StructuralTransformations.𝑠edges(graph))
30+
se = collect(StructuralTransformations.edges(graph))
3131
@test se == mapreduce(vcat, enumerate(graph.fadjlist)) do (s, d)
3232
StructuralTransformations.BipartiteEdge.(s, d)
3333
end
34-
@test_throws ArgumentError collect(StructuralTransformations.𝑑edges(graph))
35-
@test_throws ArgumentError collect(StructuralTransformations.edges(graph))

0 commit comments

Comments
 (0)