Skip to content

Commit 875c6bb

Browse files
committed
Finish tearing refactor
1 parent 3f908f8 commit 875c6bb

File tree

11 files changed

+359
-359
lines changed

11 files changed

+359
-359
lines changed

src/bipartite_graph.jl

Lines changed: 123 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module BipartiteGraphs
22

33
export BipartiteEdge, BipartiteGraph, DiCMOBiGraph, Unassigned, unassigned,
4-
Matching, ResidualCMOGraph
4+
Matching, ResidualCMOGraph, InducedCondensationGraph, maximal_matching,
5+
construct_augmenting_path!
56

67
export 𝑠vertices, 𝑑vertices, has_𝑠vertex, has_𝑑vertex, 𝑠neighbors, 𝑑neighbors,
78
𝑠edges, 𝑑edges, nsrcs, ndsts, SRC, DST, set_neighbors!, invview,
@@ -18,6 +19,14 @@ struct Unassigned
1819
global unassigned
1920
const unassigned = Unassigned.instance
2021
end
22+
# Behaves as a scalar
23+
Base.length(u::Unassigned) = 1
24+
Base.size(u::Unassigned) = ()
25+
Base.iterate(u::Unassigned) = (unassigned, nothing)
26+
Base.iterate(u::Unassigned, state) = nothing
27+
28+
Base.show(io::IO, ::Unassigned) =
29+
printstyled(io, "u"; color=:light_black)
2130

2231
struct Matching{U #=> :Unassigned =#, V<:AbstractVector} <: AbstractVector{Union{U, Int}}
2332
match::V
@@ -30,6 +39,7 @@ function Matching{V}(m::Matching) where {V}
3039
Matching{V}(convert(VUT, m.match),
3140
m.inv_match === nothing ? nothing : convert(VUT, m.inv_match))
3241
end
42+
Matching(m::Matching) = m
3343
Matching{U}(v::V) where {U, V<:AbstractVector} = Matching{U, V}(v, nothing)
3444
Matching{U}(v::V, iv::Union{V, Nothing}) where {U, V<:AbstractVector} = Matching{U, V}(v, iv)
3545
Matching(v::V) where {U, V<:AbstractVector{Union{U, Int}}} =
@@ -135,10 +145,13 @@ mutable struct BipartiteGraph{I<:Integer, M} <: Graphs.AbstractGraph{I}
135145
metadata::M
136146
end
137147
BipartiteGraph(ne::Integer, fadj::AbstractVector, badj::Union{AbstractVector,Integer}=maximum(maximum, fadj); metadata=nothing) = BipartiteGraph(ne, fadj, badj, metadata)
148+
BipartiteGraph(fadj::AbstractVector, badj::Union{AbstractVector,Integer}=maximum(maximum, fadj); metadata=nothing) =
149+
BipartiteGraph(mapreduce(length, +, fadj; init=0), fadj, badj, metadata)
138150

139151
@noinline require_complete(g::BipartiteGraph) = g.badjlist isa AbstractVector || throw(ArgumentError("The graph has no back edges. Use `complete`."))
140152

141153
function invview(g::BipartiteGraph)
154+
require_complete(g)
142155
BipartiteGraph(g.ne, g.badjlist, g.fadjlist)
143156
end
144157

@@ -215,7 +228,53 @@ ndsts(g::BipartiteGraph) = length(𝑑vertices(g))
215228
function Graphs.has_edge(g::BipartiteGraph, edge::BipartiteEdge)
216229
@unpack src, dst = edge
217230
(src in 𝑠vertices(g) && dst in 𝑑vertices(g)) || return false # edge out of bounds
218-
insorted(𝑠neighbors(src), dst)
231+
insorted(dst, 𝑠neighbors(g, src))
232+
end
233+
Base.in(edge::BipartiteEdge, g::BipartiteGraph) = Graphs.has_edge(g, edge)
234+
235+
### Maximal matching
236+
"""
237+
construct_augmenting_path!(m::Matching, g::BipartiteGraph, vsrc, dstfilter, vcolor=falses(ndsts(g)), ecolor=falses(nsrcs(g))) -> path_found::Bool
238+
239+
Try to construct an augmenting path in matching and if such a path is found,
240+
update the matching accordingly.
241+
"""
242+
function construct_augmenting_path!(matching::Matching, g::BipartiteGraph, vsrc, dstfilter, dcolor=falses(ndsts(g)), scolor=falses(nsrcs(g)))
243+
scolor[vsrc] = true
244+
245+
# if a `vdst` is unassigned and the edge `vsrc <=> vdst` exists
246+
for vdst in 𝑠neighbors(g, vsrc)
247+
if dstfilter(vdst) && matching[vdst] === unassigned
248+
matching[vdst] = vsrc
249+
return true
250+
end
251+
end
252+
253+
# for every `vsrc` such that edge `vsrc <=> vdst` exists and `vdst` is uncolored
254+
for vdst in 𝑠neighbors(g, vsrc)
255+
(dstfilter(vdst) && !dcolor[vdst]) || continue
256+
dcolor[vdst] = true
257+
if construct_augmenting_path!(matching, g, matching[vdst], dstfilter, dcolor, scolor)
258+
matching[vdst] = vsrc
259+
return true
260+
end
261+
end
262+
return false
263+
end
264+
265+
"""
266+
maximal_matching(g::BipartiteGraph, [srcfilter], [dstfilter])
267+
268+
For a bipartite graph `g`, construct a maximal matching of destination to source
269+
vertices, subject to the constraint that vertices for which `srcfilter` or `dstfilter`,
270+
return `false` may not be matched.
271+
"""
272+
function maximal_matching(g::BipartiteGraph, srcfilter=vsrc->true, dstfilter=vdst->true)
273+
matching = Matching(ndsts(g))
274+
foreach(Iterators.filter(srcfilter, 𝑠vertices(g))) do vsrc
275+
construct_augmenting_path!(matching, g, vsrc, dstfilter)
276+
end
277+
return matching
219278
end
220279

221280
###
@@ -508,4 +567,66 @@ function Graphs.neighbors(rcg::ResidualCMOGraph, v::Integer)
508567
invview(rcg.matching)[vsrc] === unassigned))
509568
end
510569

570+
# TODO: Fix the function in Graphs to do this instead
571+
function Graphs.neighborhood(rcg::ResidualCMOGraph, v::Integer)
572+
worklist = Int[v]
573+
ns = BitSet()
574+
while !isempty(worklist)
575+
v′ = popfirst!(worklist)
576+
for n in neighbors(rcg, v′)
577+
if !(n in ns)
578+
push!(ns, n)
579+
push!(worklist, n)
580+
end
581+
end
582+
end
583+
return ns
584+
end
585+
586+
"""
587+
struct InducedCondensationGraph
588+
589+
For some bipartite-graph and an orientation induced on its destination contraction,
590+
records the condensation DAG of the digraph formed by the orientation. I.e. this
591+
is a DAG of connected components formed by the destination vertices of some
592+
underlying bipartite graph.
593+
594+
N.B.: This graph does not store explicit neighbor relations of the sccs.
595+
Therefor, the edge multiplicity is derived from the underlying bipartite graph,
596+
i.e. this graph is not strict.
597+
"""
598+
struct InducedCondensationGraph{G <: BipartiteGraph} <: AbstractGraph{Vector{Union{Int, Vector{Int}}}}
599+
graph::G
600+
# Records the members of a strongly connected component. For efficiency,
601+
# trivial sccs (with one vertex member) are stored inline. Note: the sccs
602+
# here are stored in topological order.
603+
sccs::Vector{Union{Int, Vector{Int}}}
604+
# Maps the vertices back to the scc of which they are a part
605+
scc_assignment::Vector{Int}
606+
end
607+
608+
function InducedCondensationGraph(g::BipartiteGraph, sccs::Vector{Union{Int, Vector{Int}}})
609+
scc_assignment = Vector{Int}(undef, ndsts(g))
610+
for (i, c) in enumerate(sccs)
611+
for v in c
612+
scc_assignment[v] = i
613+
end
614+
end
615+
InducedCondensationGraph(g, sccs, scc_assignment)
616+
end
617+
618+
Graphs.is_directed(::Type{<:InducedCondensationGraph}) = true
619+
Graphs.nv(icg::InducedCondensationGraph) = length(icg.sccs)
620+
Graphs.vertices(icg::InducedCondensationGraph) = icg.sccs
621+
622+
_neighbors(icg::InducedCondensationGraph, cc::Integer) =
623+
Iterators.flatten(Iterators.flatten(rcg.graph.fadjlist[vsrc] for vsrc in rcg.graph.badjlist[v]) for v in icg.sccs[cc])
624+
625+
Graphs.outneighbors(rcg::InducedCondensationGraph, v::Integer) =
626+
(scc_assignment[n] for n in _neighbors(rcg, v) if scc_assignment[n] > v)
627+
628+
Graphs.inneighbors(rcg::InducedCondensationGraph, v::Integer) =
629+
(scc_assignment[n] for n in _neighbors(rcg, v) if scc_assignment[n] < v)
630+
631+
511632
end # module

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,30 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}
4242
end
4343
end
4444

45-
vSolved = filter(v->G.matching[v] !== unassigned, topological_sort(ict))
46-
inv_matching = Union{Missing, Int}[missing for _ = 1:nv(G)]
47-
for (v, eq) in pairs(G.matching)
48-
eq === unassigned && continue
49-
inv_matching[v] = eq
45+
return ict
46+
end
47+
48+
"""
49+
tear_graph_modia(sys) -> sys
50+
51+
Tear the bipartite graph in a system. End users are encouraged to call [`structural_simplify`](@ref)
52+
instead, which calls this function internally.
53+
"""
54+
function tear_graph_modia(graph::BipartiteGraph, solvable_graph::BipartiteGraph; varfilter=v->true, eqfilter=eq->true)
55+
var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter))
56+
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)
57+
58+
for vars in var_sccs
59+
filtered_vars = filter(varfilter, vars)
60+
ieqs = Int[var_eq_matching[v] for v in filtered_vars if var_eq_matching[v] !== unassigned]
61+
62+
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph); dir=:in)
63+
tearEquations!(ict, solvable_graph.fadjlist, ieqs, filtered_vars)
64+
65+
for var in vars
66+
var_eq_matching[var] = ict.graph.matching[var]
67+
end
5068
end
51-
eSolved = getindex.(Ref(inv_matching), vSolved)
52-
vTear = setdiff(vs, vSolved)
53-
eResidue = setdiff(es, eSolved)
54-
return (eSolved, vSolved, eResidue, vTear)
69+
70+
return var_eq_matching
5571
end

src/structural_transformation/codegen.jl

Lines changed: 24 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ using LinearAlgebra
22

33
const MAX_INLINE_NLSOLVE_SIZE = 8
44

5-
function torn_system_jacobian_sparsity(sys)
5+
function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs)
66
s = structure(sys)
7-
@unpack fullvars, graph, partitions = s
7+
@unpack fullvars, graph = s
88

99
# The sparsity pattern of `nlsolve(f, u, p)` w.r.t `p` is difficult to
1010
# determine in general. Consider the "simplest" case, a linear system. We
@@ -44,8 +44,9 @@ function torn_system_jacobian_sparsity(sys)
4444
# dependencies.
4545
avars2dvars = Dict{Int,Set{Int}}()
4646
c = 0
47-
for partition in partitions
48-
@unpack e_residual, v_residual = partition
47+
for scc in var_sccs
48+
v_residual = scc
49+
e_residual = [var_eq_matching[c] for c in v_residual if var_eq_matching[c] !== unassigned]
4950
# initialization
5051
for tvar in v_residual
5152
avars2dvars[tvar] = Set{Int}()
@@ -94,41 +95,6 @@ function torn_system_jacobian_sparsity(sys)
9495
sparse(I, J, true)
9596
end
9697

97-
"""
98-
partitions_dag(s::SystemStructure)
99-
100-
Return a DAG (sparse matrix) of partitions to use for parallelism.
101-
"""
102-
function partitions_dag(s::SystemStructure)
103-
@unpack partitions, graph = s
104-
105-
# `partvars[i]` contains all the states that appear in `partitions[i]`
106-
partvars = map(partitions) do partition
107-
ipartvars = Set{Int}()
108-
for req in partition.e_residual
109-
union!(ipartvars, 𝑠neighbors(graph, req))
110-
end
111-
ipartvars
112-
end
113-
114-
I, J = Int[], Int[]
115-
n = length(partitions)
116-
for (i, partition) in enumerate(partitions)
117-
for j in i+1:n
118-
# The only way for a later partition `j` to depend on an earlier
119-
# partition `i` is when `partvars[j]` contains one of tearing
120-
# variables of partition `i`.
121-
if !isdisjoint(partvars[j], partition.v_residual)
122-
# j depends on i
123-
push!(I, i)
124-
push!(J, j)
125-
end
126-
end
127-
end
128-
129-
sparse(I, J, true, n, n)
130-
end
131-
13298
"""
13399
exprs = gen_nlsolve(eqs::Vector{Equation}, vars::Vector, u0map::Dict; checkbounds = true)
134100
@@ -205,19 +171,20 @@ function build_torn_function(
205171
end
206172

207173
s = structure(sys)
208-
@unpack fullvars, partitions = s
174+
@unpack fullvars = s
175+
var_eq_matching, var_sccs = algebraic_variables_scc(sys)
209176

210177
states = map(i->s.fullvars[i], diffvars_range(s))
211178
mass_matrix_diag = ones(length(states))
212179
torn_expr = []
213180
defs = defaults(sys)
214181

215182
needs_extending = false
216-
for p in partitions
217-
@unpack e_residual, v_residual = p
218-
torn_eqs = eqs[e_residual]
219-
torn_vars = fullvars[v_residual]
220-
if length(e_residual) <= max_inlining_size
183+
for scc in var_sccs
184+
torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
185+
torn_eqs = [eqs[var_eq_matching[var]] for var in scc if var_eq_matching[var] !== unassigned]
186+
isempty(torn_eqs) && continue
187+
if length(torn_eqs) <= max_inlining_size
221188
append!(torn_expr, gen_nlsolve(torn_eqs, torn_vars, defs, checkbounds=checkbounds))
222189
else
223190
needs_extending = true
@@ -260,15 +227,15 @@ function build_torn_function(
260227
observedfun = let sys = sys, dict = Dict()
261228
function generated_observed(obsvar, u, p, t)
262229
obs = get!(dict, value(obsvar)) do
263-
build_observed_function(sys, obsvar, checkbounds=checkbounds)
230+
build_observed_function(sys, obsvar, var_eq_matching, var_sccs, checkbounds=checkbounds)
264231
end
265232
obs(u, p, t)
266233
end
267234
end
268235

269236
ODEFunction{true}(
270237
@RuntimeGeneratedFunction(expr),
271-
sparsity = torn_system_jacobian_sparsity(sys),
238+
sparsity = torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs),
272239
syms = syms,
273240
observed = observedfun,
274241
mass_matrix = mass_matrix,
@@ -277,24 +244,24 @@ function build_torn_function(
277244
end
278245

279246
"""
280-
find_solve_sequence(partitions, vars)
247+
find_solve_sequence(sccs, vars)
281248
282249
given a set of `vars`, find the groups of equations we need to solve for
283250
to obtain the solution to `vars`
284251
"""
285-
function find_solve_sequence(partitions, vars)
286-
subset = filter(x -> !isdisjoint(x.v_residual, vars), partitions)
252+
function find_solve_sequence(sccs, vars)
253+
subset = filter(i -> !isdisjoint(sccs[i], vars), 1:length(sccs))
287254
isempty(subset) && return []
288-
vars′ = mapreduce(x->x.v_residual, union, subset)
255+
vars′ = mapreduce(i->sccs[i], union, subset)
289256
if vars′ == vars
290257
return subset
291258
else
292-
return find_solve_sequence(partitions, vars′)
259+
return find_solve_sequence(sccs, vars′)
293260
end
294261
end
295262

296263
function build_observed_function(
297-
sys, ts;
264+
sys, ts, var_eq_matching, var_sccs;
298265
expression=false,
299266
output_type=Array,
300267
checkbounds=true
@@ -311,7 +278,7 @@ function build_observed_function(
311278
dep_vars = collect(setdiff(vars, ivs))
312279

313280
s = structure(sys)
314-
@unpack partitions, fullvars, graph = s
281+
@unpack fullvars, graph = s
315282
diffvars = map(i->fullvars[i], diffvars_range(s))
316283
algvars = map(i->fullvars[i], algvars_range(s))
317284

@@ -339,12 +306,12 @@ function build_observed_function(
339306
end
340307

341308
varidxs = findall(x->x in required_algvars, fullvars)
342-
subset = find_solve_sequence(partitions, varidxs)
309+
subset = find_solve_sequence(var_sccs, varidxs)
343310
if !isempty(subset)
344311
eqs = equations(sys)
345312

346-
torn_eqs = map(idxs-> eqs[idxs.e_residual], subset)
347-
torn_vars = map(idxs->fullvars[idxs.v_residual], subset)
313+
torn_eqs = map(i->map(v->eqs[var_eq_matching[v]], var_sccs[i]), subset)
314+
torn_vars = map(i->map(v->fullvars[v], var_sccs[i]), subset)
348315
u0map = defaults(sys)
349316
solves = gen_nlsolve.(torn_eqs, torn_vars, (u0map,); checkbounds=checkbounds)
350317
else

src/structural_transformation/pantelides.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ function pantelides!(graph, var_to_diff; maxiters = 8000)
101101
fill!(vcolor, false)
102102
resize!(ecolor, neqs)
103103
fill!(ecolor, false)
104-
pathfound = find_augmenting_path(graph, eq′, var_eq_matching, varwhitelist, vcolor, ecolor)
104+
pathfound = construct_augmenting_path!(var_eq_matching, graph, eq′, v->varwhitelist[v], vcolor, ecolor)
105105
pathfound && break # terminating condition
106106
for var in eachindex(vcolor); vcolor[var] || continue
107107
# introduce a new variable

0 commit comments

Comments
 (0)