Skip to content

Commit 7fd455d

Browse files
authored
Merge pull request #1404 from SciML/myb/odaefix
Fix ODAEProblem codegen
2 parents 92043ac + fe92b60 commit 7fd455d

File tree

4 files changed

+156
-58
lines changed

4 files changed

+156
-58
lines changed

src/bipartite_graph.jl

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module BipartiteGraphs
22

33
export BipartiteEdge, BipartiteGraph, DiCMOBiGraph, Unassigned, unassigned,
44
Matching, ResidualCMOGraph, InducedCondensationGraph, maximal_matching,
5-
construct_augmenting_path!
5+
construct_augmenting_path!, MatchedCondensationGraph
66

77
export 𝑠vertices, 𝑑vertices, has_𝑠vertex, has_𝑑vertex, 𝑠neighbors, 𝑑neighbors,
88
𝑠edges, 𝑑edges, nsrcs, ndsts, SRC, DST, set_neighbors!, invview,
@@ -516,4 +516,80 @@ end
516516
Graphs.has_edge(g::DiCMOBiGraph{true}, a, b) = a in inneighbors(g, b)
517517
Graphs.has_edge(g::DiCMOBiGraph{false}, a, b) = b in outneighbors(g, a)
518518

519+
# Condensation Graphs
520+
abstract type AbstractCondensationGraph <: AbstractGraph{Int}; end
521+
function (T::Type{<:AbstractCondensationGraph})(g, sccs::Vector{Union{Int, Vector{Int}}})
522+
scc_assignment = Vector{Int}(undef, isa(g, BipartiteGraph) ? ndsts(g) : nv(g))
523+
for (i, c) in enumerate(sccs)
524+
for v in c
525+
scc_assignment[v] = i
526+
end
527+
end
528+
T(g, sccs, scc_assignment)
529+
end
530+
(T::Type{<:AbstractCondensationGraph})(g, sccs::Vector{Vector{Int}}) =
531+
T(g, Vector{Union{Int, Vector{Int}}}(sccs))
532+
533+
Graphs.is_directed(::Type{<:AbstractCondensationGraph}) = true
534+
Graphs.nv(icg::AbstractCondensationGraph) = length(icg.sccs)
535+
Graphs.vertices(icg::AbstractCondensationGraph) = Base.OneTo(nv(icg))
536+
537+
"""
538+
struct MatchedCondensationGraph
539+
540+
For some bipartite-graph and an orientation induced on its destination contraction,
541+
records the condensation DAG of the digraph formed by the orientation. I.e. this
542+
is a DAG of connected components formed by the destination vertices of some
543+
underlying bipartite graph.
544+
N.B.: This graph does not store explicit neighbor relations of the sccs.
545+
Therefor, the edge multiplicity is derived from the underlying bipartite graph,
546+
i.e. this graph is not strict.
547+
"""
548+
struct MatchedCondensationGraph{G <: DiCMOBiGraph} <: AbstractCondensationGraph
549+
graph::G
550+
# Records the members of a strongly connected component. For efficiency,
551+
# trivial sccs (with one vertex member) are stored inline. Note: the sccs
552+
# here need not be stored in topological order.
553+
sccs::Vector{Union{Int, Vector{Int}}}
554+
# Maps the vertices back to the scc of which they are a part
555+
scc_assignment::Vector{Int}
556+
end
557+
558+
559+
Graphs.outneighbors(mcg::MatchedCondensationGraph, cc::Integer) =
560+
Iterators.flatten((mcg.scc_assignment[v′] for v′ in outneighbors(mcg.graph, v) if mcg.scc_assignment[v′] != cc) for v in mcg.sccs[cc])
561+
562+
Graphs.inneighbors(mcg::MatchedCondensationGraph, cc::Integer) =
563+
Iterators.flatten((mcg.scc_assignment[v′] for v′ in inneighbors(mcg.graph, v) if mcg.scc_assignment[v′] != cc) for v in mcg.sccs[cc])
564+
565+
"""
566+
struct InducedCondensationGraph
567+
568+
For some bipartite-graph and a topologicall sorted list of connected components,
569+
represents the condensation DAG of the digraph formed by the orientation. I.e. this
570+
is a DAG of connected components formed by the destination vertices of some
571+
underlying bipartite graph.
572+
N.B.: This graph does not store explicit neighbor relations of the sccs.
573+
Therefor, the edge multiplicity is derived from the underlying bipartite graph,
574+
i.e. this graph is not strict.
575+
"""
576+
struct InducedCondensationGraph{G <: BipartiteGraph} <: AbstractCondensationGraph
577+
graph::G
578+
# Records the members of a strongly connected component. For efficiency,
579+
# trivial sccs (with one vertex member) are stored inline. Note: the sccs
580+
# here are stored in topological order.
581+
sccs::Vector{Union{Int, Vector{Int}}}
582+
# Maps the vertices back to the scc of which they are a part
583+
scc_assignment::Vector{Int}
584+
end
585+
586+
_neighbors(icg::InducedCondensationGraph, cc::Integer) =
587+
Iterators.flatten(Iterators.flatten(icg.graph.fadjlist[vsrc] for vsrc in icg.graph.badjlist[v]) for v in icg.sccs[cc])
588+
589+
Graphs.outneighbors(icg::InducedCondensationGraph, v::Integer) =
590+
(icg.scc_assignment[n] for n in _neighbors(icg, v) if icg.scc_assignment[n] > v)
591+
592+
Graphs.inneighbors(icg::InducedCondensationGraph, v::Integer) =
593+
(icg.scc_assignment[n] for n in _neighbors(icg, v) if icg.scc_assignment[n] < v)
594+
519595
end # module

src/compat/incremental_cycles.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Base.Iterators: repeated
2+
import Graphs.Experimental.Traversals: topological_sort
23

34
# Abstract Interface
45

src/structural_transformation/codegen.jl

Lines changed: 63 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfindi
44

55
const MAX_INLINE_NLSOLVE_SIZE = 8
66

7-
function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs)
7+
function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, states_idxs)
88
s = structure(sys)
99
@unpack fullvars, graph = s
1010

@@ -42,54 +42,52 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs)
4242
# from previous partitions. Hence, we can build the dependency chain as we
4343
# traverse the partitions.
4444

45-
# `avars2dvars` maps a algebraic variable to its differential variable
46-
# dependencies.
47-
avars2dvars = Dict{Int,Set{Int}}()
48-
c = 0
49-
for scc in var_sccs
50-
v_residual = scc
51-
e_residual = [var_eq_matching[c] for c in v_residual if var_eq_matching[c] !== unassigned]
52-
# initialization
53-
for tvar in v_residual
54-
avars2dvars[tvar] = Set{Int}()
45+
var_rename = ones(Int64, ndsts(graph))
46+
nlsolve_vars = Int[]
47+
for i in nlsolve_scc_idxs, c in var_sccs[i]
48+
append!(nlsolve_vars, c)
49+
for v in c
50+
var_rename[v] = 0
5551
end
56-
for teq in e_residual
57-
c += 1
58-
for var in 𝑠neighbors(graph, teq)
59-
# Skip the tearing variables in the current partition, because
60-
# we are computing them from all the other states.
61-
Graphs.insorted(var, v_residual) && continue
62-
deps = get(avars2dvars, var, nothing)
63-
if deps === nothing # differential variable
64-
@assert !isalgvar(s, var)
65-
for tvar in v_residual
66-
push!(avars2dvars[tvar], var)
67-
end
68-
else # tearing variable from previous partitions
69-
@assert isalgvar(s, var)
70-
for tvar in v_residual
71-
union!(avars2dvars[tvar], avars2dvars[var])
72-
end
73-
end
52+
end
53+
masked_cumsum!(var_rename)
54+
55+
dig = DiCMOBiGraph{true}(graph, var_eq_matching)
56+
57+
fused_var_deps = map(1:ndsts(graph)) do v
58+
BitSet(var_rename[v′] for v′ in neighborhood(dig, v, Inf; dir=:in) if var_rename[v′] != 0)
59+
end
60+
61+
for scc in var_sccs
62+
if length(scc) >= 2
63+
deps = fused_var_deps[scc[1]]
64+
for c in 2:length(scc)
65+
union!(deps, fused_var_deps[c])
7466
end
7567
end
7668
end
7769

78-
dvrange = diffvars_range(s)
79-
dvar2idx = Dict(v=>i for (i, v) in enumerate(dvrange))
70+
nlsolve_eqs = BitSet(var_eq_matching[c]::Int for c in nlsolve_vars if var_eq_matching[c] !== unassigned)
71+
72+
var2idx = Dict(v => i for (i, v) in enumerate(states_idxs))
73+
nlsolve_vars_set = BitSet(nlsolve_vars)
74+
8075
I = Int[]; J = Int[]
8176
eqidx = 0
8277
for ieq in 𝑠vertices(graph)
83-
isalgeq(s, ieq) && continue
78+
ieq in nlsolve_eqs && continue
8479
eqidx += 1
8580
for ivar in 𝑠neighbors(graph, ieq)
86-
if isdiffvar(s, ivar)
81+
isdervar(s, ivar) && continue
82+
if var_rename[ivar] != 0
8783
push!(I, eqidx)
88-
push!(J, dvar2idx[ivar])
89-
elseif isalgvar(s, ivar)
90-
for dvar in avars2dvars[ivar]
84+
push!(J, var2idx[ivar])
85+
else
86+
for dvar in fused_var_deps[ivar]
87+
isdervar(s, dvar) && continue
88+
dvar in nlsolve_vars_set && continue
9189
push!(I, eqidx)
92-
push!(J, dvar2idx[dvar])
90+
push!(J, var2idx[dvar])
9391
end
9492
end
9593
end
@@ -123,7 +121,17 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
123121
params = setdiff(allvars, vars) # these are not the subject of the root finding
124122

125123
# splatting to tighten the type
126-
u0 = [map(var->get(u0map, var, 1e-3), vars)...]
124+
u0 = []
125+
for v in vars
126+
v in keys(u0map) || (push!(u0, 1e-3); continue)
127+
u = substitute(v, u0map)
128+
for i in 1:length(u0map)
129+
u = substitute(u, u0map)
130+
u isa Number && (push!(u0, u); break)
131+
end
132+
u isa Number || error("$v doesn't have a default.")
133+
end
134+
u0 = [u0...]
127135
# specialize on the scalar case
128136
isscalar = length(u0) == 1
129137
u0 = isscalar ? u0[1] : SVector(u0...)
@@ -175,23 +183,30 @@ function build_torn_function(
175183
s = structure(sys)
176184
@unpack fullvars = s
177185
var_eq_matching, var_sccs = algebraic_variables_scc(sys)
186+
condensed_graph = MatchedCondensationGraph(
187+
DiCMOBiGraph{true}(complete(s.graph), complete(var_eq_matching)), var_sccs)
188+
toporder = topological_sort_by_dfs(condensed_graph)
189+
var_sccs = var_sccs[toporder]
178190

179-
states = map(i->s.fullvars[i], diffvars_range(s))
180-
mass_matrix_diag = ones(length(states))
191+
states_idxs = collect(diffvars_range(s))
192+
mass_matrix_diag = ones(length(states_idxs))
181193
torn_expr = []
182194
defs = defaults(sys)
195+
nlsolve_scc_idxs = Int[]
183196

184197
needs_extending = false
185-
for scc in var_sccs
186-
torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
187-
torn_eqs = [eqs[var_eq_matching[var]] for var in scc if var_eq_matching[var] !== unassigned]
198+
for (i, scc) in enumerate(var_sccs)
199+
#torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
200+
torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] !== unassigned]
201+
torn_eqs = [eqs[var_eq_matching[var]] for var in torn_vars_idxs]
188202
isempty(torn_eqs) && continue
189203
if length(torn_eqs) <= max_inlining_size
190-
append!(torn_expr, gen_nlsolve(torn_eqs, torn_vars, defs, checkbounds=checkbounds))
204+
append!(torn_expr, gen_nlsolve(torn_eqs, s.fullvars[torn_vars_idxs], defs, checkbounds=checkbounds))
205+
push!(nlsolve_scc_idxs, i)
191206
else
192207
needs_extending = true
193208
append!(rhss, map(x->x.rhs, torn_eqs))
194-
append!(states, torn_vars)
209+
append!(states_idxs, torn_vars_idxs)
195210
append!(mass_matrix_diag, zeros(length(torn_eqs)))
196211
end
197212
end
@@ -205,7 +220,8 @@ function build_torn_function(
205220
rhss
206221
)
207222

208-
syms = map(Symbol, states)
223+
states = s.fullvars[states_idxs]
224+
syms = map(Symbol, states_idxs)
209225
pre = get_postprocess_fbody(sys)
210226

211227
expr = SymbolicUtils.Code.toexpr(
@@ -237,7 +253,7 @@ function build_torn_function(
237253

238254
ODEFunction{true}(
239255
@RuntimeGeneratedFunction(expr),
240-
sparsity = torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs),
256+
sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, states_idxs) : nothing,
241257
syms = syms,
242258
observed = observedfun,
243259
mass_matrix = mass_matrix,

src/structural_transformation/tearing.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,24 @@ function contract_variables(graph::BipartiteGraph, var_eq_matching::Matching, el
3838
[var_rename[v′] for v′ in neighborhood(dig, v, Inf; dir=:in) if var_rename[v′] != 0]
3939
end
4040

41-
new_fadjlist = Vector{Int}[
42-
let new_list = Vector{Int}()
43-
for v in graph.fadjlist[i]
44-
if var_rename[v] != 0
45-
push!(new_list, var_rename[v])
46-
else
47-
append!(new_list, var_deps[v])
41+
nelim = length(eliminated_variables)
42+
newgraph = BipartiteGraph(nsrcs(graph) - nelim, ndsts(graph) - nelim)
43+
for e in 𝑠vertices(graph)
44+
ne = eq_rename[e]
45+
ne == 0 && continue
46+
for v in 𝑠neighbors(graph, e)
47+
newvar = var_rename[v]
48+
if newvar != 0
49+
add_edge!(newgraph, ne, newvar)
50+
else
51+
for nv in var_deps[v]
52+
add_edge!(newgraph, ne, nv)
4853
end
4954
end
50-
new_list
51-
end for i = 1:nsrcs(graph) if eq_rename[i] != 0]
55+
end
56+
end
5257

53-
return BipartiteGraph(new_fadjlist, ndsts(graph) - length(eliminated_variables))
58+
return newgraph
5459
end
5560

5661
"""

0 commit comments

Comments
 (0)