Skip to content

Commit 6c1e78b

Browse files
committed
Fill out Graphs API a bit more
Just a few odds and ends I was missing in the API as I was playing with these new datastructures.
1 parent 9cc6f21 commit 6c1e78b

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

src/bipartite_graph.jl

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,12 @@ Matching(m::Matching) = m
3131
Base.size(m::Matching) = Base.size(m.match)
3232
Base.getindex(m::Matching, i::Integer) = m.match[i]
3333
Base.iterate(m::Matching, state...) = iterate(m.match, state...)
34-
function Base.setindex!(m::Matching, v::Integer, i::Integer)
34+
Base.copy(m::Matching) = Matching(copy(m.match), m.inv_match === nothing ? nothing : copy(m.inv_match))
35+
function Base.setindex!(m::Matching, v::Union{Integer, Unassigned}, i::Integer)
3536
if m.inv_match !== nothing
36-
m.inv_match[v] = i
37+
oldv = m.match[i]
38+
oldv !== unassigned && (m.inv_match[oldv] = unassigned)
39+
v !== unassigned && (m.inv_match[v] = i)
3740
end
3841
return m.match[i] = v
3942
end
@@ -55,8 +58,11 @@ function complete(m::Matching)
5558
return Matching(collect(m.match), inv_match)
5659
end
5760

58-
function invview(m::Matching)
61+
@noinline require_complete(m::Matching) =
5962
m.inv_match === nothing && throw(ArgumentError("Backwards matching not defined. `complete` the matching first."))
63+
64+
function invview(m::Matching)
65+
require_complete(m)
6066
return Matching(m.inv_match, m.match)
6167
end
6268

@@ -122,6 +128,23 @@ mutable struct BipartiteGraph{I<:Integer, M} <: Graphs.AbstractGraph{I}
122128
end
123129
BipartiteGraph(ne::Integer, fadj::AbstractVector, badj::Union{AbstractVector,Integer}=maximum(maximum, fadj); metadata=nothing) = BipartiteGraph(ne, fadj, badj, metadata)
124130

131+
@noinline require_complete(g::BipartiteGraph) = g.badjlist isa AbstractVector || throw(ArgumentError("The graph has no back edges. Use `complete`."))
132+
133+
function invview(g::BipartiteGraph)
134+
BipartiteGraph(g.ne, g.badjlist, g.fadjlist)
135+
end
136+
137+
function complete(g::BipartiteGraph{I}) where {I}
138+
isa(g.badjlist, AbstractVector) && return g
139+
badjlist = Vector{I}[Vector{I}() for _ in 1:g.badjlist]
140+
for (s, l) in enumerate(g.fadjlist)
141+
for d in l
142+
push!(badjlist[d], s)
143+
end
144+
end
145+
BipartiteGraph(g.ne, g.fadjlist, badjlist)
146+
end
147+
125148
"""
126149
```julia
127150
Base.isequal(bg1::BipartiteGraph{T}, bg2::BipartiteGraph{T}) where {T<:Integer}
@@ -147,6 +170,7 @@ function BipartiteGraph(nsrcs::T, ndsts::T, backedge::Val{B}=Val(true); metadata
147170
BipartiteGraph(0, fadjlist, badjlist, metadata)
148171
end
149172

173+
Base.copy(bg::BipartiteGraph) = BipartiteGraph(bg.ne, copy(bg.fadjlist), copy(bg.badjlist), deepcopy(bg.metadata))
150174
Base.eltype(::Type{<:BipartiteGraph{I}}) where I = I
151175
function Base.empty!(g::BipartiteGraph)
152176
foreach(empty!, g.fadjlist)
@@ -159,8 +183,6 @@ function Base.empty!(g::BipartiteGraph)
159183
end
160184
Base.length(::BipartiteGraph) = error("length is not well defined! Use `ne` or `nv`.")
161185

162-
@noinline throw_no_back_edges() = throw(ArgumentError("The graph has no back edges."))
163-
164186
if isdefined(Graphs, :has_contiguous_vertices)
165187
Graphs.has_contiguous_vertices(::Type{<:BipartiteGraph}) = false
166188
end
@@ -172,7 +194,7 @@ has_𝑠vertex(g::BipartiteGraph, v::Integer) = v in 𝑠vertices(g)
172194
has_𝑑vertex(g::BipartiteGraph, v::Integer) = v in 𝑑vertices(g)
173195
𝑠neighbors(g::BipartiteGraph, i::Integer, with_metadata::Val{M}=Val(false)) where M = M ? zip(g.fadjlist[i], g.metadata[i]) : g.fadjlist[i]
174196
function 𝑑neighbors(g::BipartiteGraph, j::Integer, with_metadata::Val{M}=Val(false)) where M
175-
g.badjlist isa AbstractVector || throw_no_back_edges()
197+
require_complete(g)
176198
M ? zip(g.badjlist[j], (g.metadata[i][j] for i in g.badjlist[j])) : g.badjlist[j]
177199
end
178200
Graphs.ne(g::BipartiteGraph) = g.ne
@@ -348,6 +370,9 @@ function DiCMOBiGraph{Transposed}(g::BipartiteGraph, m::M) where {Transposed, M}
348370
DiCMOBiGraph{Transposed}(g, missing, m)
349371
end
350372

373+
invview(g::DiCMOBiGraph{Transposed}) where {Transposed} =
374+
DiCMOBiGraph{!Transposed}(invview(g.graph), g.ne, invview(g.matching))
375+
351376
Graphs.is_directed(::Type{<:DiCMOBiGraph}) = true
352377
Graphs.nv(g::DiCMOBiGraph{Transposed}) where {Transposed} = Transposed ? ndsts(g.graph) : nsrcs(g.graph)
353378
Graphs.vertices(g::DiCMOBiGraph{Transposed}) where {Transposed} = Transposed ? 𝑑vertices(g.graph) : 𝑠vertices(g.graph)
@@ -360,6 +385,8 @@ struct CMONeighbors{Transposed, V}
360385
end
361386

362387
Graphs.outneighbors(g::DiCMOBiGraph{false}, v) = CMONeighbors{false}(g, v)
388+
Graphs.inneighbors(g::DiCMOBiGraph{false}, v) = CMONeighbors{true}(invview(g), v)
389+
Graphs.all_neighbors(g::DiCMOBiGraph{true}, v::Integer) = 𝑠neighbors(g.graph, v)
363390
Base.iterate(c::CMONeighbors{false}) = iterate(c, (c.g.graph.fadjlist[c.v],))
364391
function Base.iterate(c::CMONeighbors{false}, (l, state...))
365392
while true
@@ -376,13 +403,16 @@ function Base.iterate(c::CMONeighbors{false}, (l, state...))
376403
return vsrc, (l, r[2])
377404
end
378405
end
406+
Base.length(c::CMONeighbors{false}) = count(_->true, c)
379407

380408
lift(f, x) = (x === unassigned || isnothing(x)) ? nothing : f(x)
381409

382410
_vsrc(c::CMONeighbors{true}) = c.g.matching[c.v]
383411
_neighbors(c::CMONeighbors{true}) = lift(vsrc->c.g.graph.fadjlist[vsrc], _vsrc(c))
384412
Base.length(c::CMONeighbors{true}) = something(lift(length, _neighbors(c)), 1) - 1
385413
Graphs.inneighbors(g::DiCMOBiGraph{true}, v) = CMONeighbors{true}(g, v)
414+
Graphs.outneighbors(g::DiCMOBiGraph{true}, v) = CMONeighbors{false}(invview(g), v)
415+
Graphs.all_neighbors(g::DiCMOBiGraph{true}, v::Integer) = 𝑑neighbors(g.graph, v)
386416
Base.iterate(c::CMONeighbors{true}) = lift(ns->iterate(c, (ns,)), _neighbors(c))
387417
function Base.iterate(c::CMONeighbors{true}, (l, state...))
388418
while true
@@ -396,16 +426,15 @@ function Base.iterate(c::CMONeighbors{true}, (l, state...))
396426
end
397427
end
398428

429+
399430
_edges(g::DiCMOBiGraph{Transposed}) where Transposed = Transposed ?
400431
((w=>v for w in inneighbors(g, v)) for v in vertices(g)) :
401432
((v=>w for w in outneighbors(g, v)) for v in vertices(g))
402-
_count(c::CMONeighbors{true}) = length(c)
403-
_count(c::CMONeighbors{false}) = count(_->true, c)
404433

405434
Graphs.edges(g::DiCMOBiGraph) = (Graphs.SimpleEdge(p) for p in Iterators.flatten(_edges(g)))
406435
function Graphs.ne(g::DiCMOBiGraph)
407436
if g.ne === missing
408-
g.ne = mapreduce(x->_count(x.iter), +, _edges(g))
437+
g.ne = mapreduce(x->length(x.iter), +, _edges(g))
409438
end
410439
return g.ne
411440
end

src/systems/systemstructure.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ Base.eltype(::DiffGraph) = Union{Int, Nothing}
114114
Base.size(dg::DiffGraph) = size(dg.primal_to_diff)
115115
Base.length(dg::DiffGraph) = length(dg.primal_to_diff)
116116
Base.getindex(dg::DiffGraph, var::Integer) = dg.primal_to_diff[var]
117+
Base.getindex(dg::DiffGraph, a::AbstractArray) = [dg[x] for x in a]
118+
117119
function Base.setindex!(dg::DiffGraph, val::Union{Integer, Nothing}, var::Integer)
118120
if dg.diff_to_primal !== nothing
119121
old_pd = dg.primal_to_diff[var]
@@ -132,7 +134,7 @@ Base.iterate(dg::DiffGraph, state...) = iterate(dg.primal_to_diff, state...)
132134

133135
function complete(dg::DiffGraph)
134136
dg.diff_to_primal !== nothing && return dg
135-
diff_to_primal = zeros(Int, length(dg.primal_to_diff))
137+
diff_to_primal = Union{Int, Nothing}[nothing for _ = 1:length(dg.primal_to_diff)]
136138
for (var, diff) in edges(dg)
137139
diff_to_primal[diff] = var
138140
end

0 commit comments

Comments
 (0)