Skip to content

Commit 23cafd0

Browse files
authored
Merge pull request #1384 from Keno/kf/state_selection
Partial State Selection
2 parents a18de30 + df13e04 commit 23cafd0

26 files changed

+718
-313
lines changed

docs/src/tutorials/tearing_parallelism.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ investigate what this means:
178178

179179
```julia
180180
using ModelingToolkit.BipartiteGraphs
181-
big_rc = initialize_system_structure(expand_connections(big_rc))
182-
inc_org = BipartiteGraphs.incidence_matrix(structure(big_rc).graph)
181+
ts = TearingState(expand_connections(big_rc))
182+
inc_org = BipartiteGraphs.incidence_matrix(ts.graph)
183183
blt_org = StructuralTransformations.sorted_incidence_matrix(big_rc, only_algeqs=true, only_algvars=true)
184184
blt_reduced = StructuralTransformations.sorted_incidence_matrix(sys, only_algeqs=true, only_algvars=true)
185185
```
@@ -190,7 +190,7 @@ The figure on the left is the original incidence matrix of the algebraic equatio
190190
Notice that the original formulation of the model has dependencies between different
191191
equations, and so the full set of equations must be solved together. That exposes
192192
no parallelism. However, the Block Lower Triangular (BLT) transformation exposes
193-
independent blocks. This is then further impoved by the tearing process, which
193+
independent blocks. This is then further improved by the tearing process, which
194194
removes 90% of the equations and transforms the nonlinear equations into 50
195195
independent blocks *which can now all be solved in parallel*. The conclusion
196196
is that, your attempts to parallelize are neigh: performing parallelism after

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ export calculate_factorized_W, generate_factorized_W
193193
export calculate_hessian, generate_hessian
194194
export calculate_massmatrix, generate_diffusion_function
195195
export stochastic_integral_transform
196-
export initialize_system_structure
196+
export TearingState, StateSelectionState
197197
export generate_difference_cb
198198

199199
export BipartiteGraph, equation_dependencies, variable_dependencies

src/bipartite_graph.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,25 @@ function Graphs.add_edge!(g::BipartiteGraph, edge::BipartiteEdge, md=NO_METADATA
306306
return true # edge successfully added
307307
end
308308

309+
Graphs.rem_edge!(g::BipartiteGraph, i::Integer, j::Integer) =
310+
Graphs.rem_edge!(g, BipartiteEdge(i, j))
311+
function Graphs.rem_edge!(g::BipartiteGraph, edge::BipartiteEdge)
312+
@unpack fadjlist, badjlist = g
313+
s, d = src(edge), dst(edge)
314+
(has_𝑠vertex(g, s) && has_𝑑vertex(g, d)) || error("edge ($edge) out of range.")
315+
@inbounds list = fadjlist[s]
316+
index = searchsortedfirst(list, d)
317+
@inbounds (index <= length(list) && list[index] == d) || error("graph does not have edge $edge")
318+
deleteat!(list, index)
319+
g.ne -= 1
320+
if badjlist isa AbstractVector
321+
@inbounds list = badjlist[d]
322+
index = searchsortedfirst(list, s)
323+
deleteat!(list, index)
324+
end
325+
return true # edge successfully deleted
326+
end
327+
309328
function Graphs.add_vertex!(g::BipartiteGraph{T}, type::VertType) where T
310329
if type === DST
311330
if g.badjlist isa AbstractVector
@@ -322,10 +341,23 @@ function Graphs.add_vertex!(g::BipartiteGraph{T}, type::VertType) where T
322341
end
323342

324343
function set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors::AbstractVector)
325-
old_nneighbors = length(g.fadjlist[i])
344+
old_neighbors = g.fadjlist[i]
345+
old_nneighbors = length(old_neighbors)
326346
new_nneighbors = length(new_neighbors)
327347
g.fadjlist[i] = new_neighbors
328348
g.ne += new_nneighbors - old_nneighbors
349+
if isa(g.badjlist, AbstractVector)
350+
for n in old_neighbors
351+
@inbounds list = g.badjlist[n]
352+
index = searchsortedfirst(list, i)
353+
deleteat!(list, index)
354+
end
355+
for n in new_neighbors
356+
@inbounds list = g.badjlist[n]
357+
index = searchsortedfirst(list, i)
358+
insert!(list, index, i)
359+
end
360+
end
329361
end
330362

331363
###

src/structural_transformation/StructuralTransformations.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,24 @@ using SymbolicUtils.Rewriters
1010
using SymbolicUtils: similarterm, istree
1111

1212
using ModelingToolkit
13-
using ModelingToolkit: ODESystem, AbstractSystem,var_from_nested_derivative, Differential,
13+
using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Differential,
1414
states, equations, vars, Symbolic, diff2term, value,
1515
operation, arguments, Sym, Term, simplify, solve_for,
1616
isdiffeq, isdifferential, isinput,
1717
empty_substitutions, get_substitutions,
18-
get_structure, get_iv, independent_variables,
19-
has_structure, defaults, InvalidSystemException,
18+
get_tearing_state, get_iv, independent_variables,
19+
has_tearing_state, defaults, InvalidSystemException,
2020
ExtraEquationsSystemException,
2121
ExtraVariablesSystemException,
2222
get_postprocess_fbody, vars!,
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
24-
invalidate_cache!, Substitutions
24+
invalidate_cache!, Substitutions, get_or_construct_tearing_state
2525

2626
using ModelingToolkit.BipartiteGraphs
2727
import .BipartiteGraphs: invview
2828
using Graphs
2929
using ModelingToolkit.SystemStructures
30+
using ModelingToolkit.SystemStructures: algeqs
3031

3132
using ModelingToolkit.DiffEqBase
3233
using ModelingToolkit.StaticArrays
@@ -38,10 +39,10 @@ using SparseArrays
3839

3940
using NonlinearSolve
4041

41-
export tearing, dae_index_lowering, check_consistency
42-
export tearing_assignments, tearing_substitution
42+
export tearing, partial_state_selection, dae_index_lowering, check_consistency
4343
export build_torn_function, build_observed_function, ODAEProblem
44-
export sorted_incidence_matrix
44+
export sorted_incidence_matrix, pantelides!, tearing_reassemble, find_solvables!
45+
export tearing_assignments, tearing_substitution
4546
export torn_system_jacobian_sparsity
4647
export full_equations
4748

@@ -50,6 +51,7 @@ include("pantelides.jl")
5051
include("bipartite_tearing/modia_tearing.jl")
5152
include("tearing.jl")
5253
include("symbolics_tearing.jl")
54+
include("partial_state_selection.jl")
5355
include("codegen.jl")
5456

5557
end # module

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@
1313
#
1414
################################################
1515

16+
function try_assign_eq!(ict::IncrementalCycleTracker, vj::Integer, eq::Integer)
17+
G = ict.graph
18+
add_edge_checked!(ict, Iterators.filter(!=(vj), 𝑠neighbors(G.graph, eq)), vj) do G
19+
G.matching[vj] = eq
20+
G.ne += length(𝑠neighbors(G.graph, eq)) - 1
21+
end
22+
end
23+
1624
"""
1725
(eSolved, vSolved, eResidue, vTear) = tearEquations!(td, Gsolvable, es, vs)
1826
@@ -33,10 +41,7 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}
3341
for eq in es # iterate only over equations that are not in eSolvedFixed
3442
for vj in Gsolvable[eq]
3543
if G.matching[vj] === unassigned && (vj in vActive)
36-
r = add_edge_checked!(ict, Iterators.filter(!=(vj), 𝑠neighbors(G.graph, eq)), vj) do G
37-
G.matching[vj] = eq
38-
G.ne += length(𝑠neighbors(G.graph, eq)) - 1
39-
end
44+
r = try_assign_eq!(ict, vj, eq)
4045
r && break
4146
end
4247
end
@@ -45,6 +50,15 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}
4550
return ict
4651
end
4752

53+
function tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, eqs, vars)
54+
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph); dir=:in)
55+
tearEquations!(ict, solvable_graph.fadjlist, eqs, vars)
56+
for var in vars
57+
var_eq_matching[var] = ict.graph.matching[var]
58+
end
59+
return nothing
60+
end
61+
4862
"""
4963
tear_graph_modia(sys) -> sys
5064
@@ -58,13 +72,10 @@ function tear_graph_modia(graph::BipartiteGraph, solvable_graph::BipartiteGraph;
5872
for vars in var_sccs
5973
filtered_vars = filter(varfilter, vars)
6074
ieqs = Int[var_eq_matching[v] for v in filtered_vars if var_eq_matching[v] !== unassigned]
61-
62-
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph); dir=:in)
63-
tearEquations!(ict, solvable_graph.fadjlist, ieqs, filtered_vars)
64-
6575
for var in vars
66-
var_eq_matching[var] = ict.graph.matching[var]
76+
var_eq_matching[var] = unassigned
6777
end
78+
tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, ieqs, filtered_vars)
6879
end
6980

7081
return var_eq_matching

src/structural_transformation/codegen.jl

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfindi
44

55
const MAX_INLINE_NLSOLVE_SIZE = 8
66

7-
function torn_system_with_nlsolve_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
8-
s = structure(sys)
9-
@unpack fullvars, graph = s
7+
function torn_system_with_nlsolve_jacobian_sparsity(state, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
8+
fullvars = state.fullvars
9+
graph = state.structure.graph
1010

1111
# The sparsity pattern of `nlsolve(f, u, p)` w.r.t `p` is difficult to
1212
# determine in general. Consider the "simplest" case, a linear system. We
@@ -73,6 +73,7 @@ function torn_system_with_nlsolve_jacobian_sparsity(sys, var_eq_matching, var_sc
7373
nlsolve_vars_set = BitSet(nlsolve_vars)
7474

7575
I = Int[]; J = Int[]
76+
s = state.structure
7677
for ieq in 𝑠vertices(graph)
7778
nieq = get(eqs2idx, ieq, 0)
7879
nieq == 0 && continue
@@ -244,15 +245,15 @@ function build_torn_function(
244245
push!(rhss, eq.rhs)
245246
end
246247

247-
s = structure(sys)
248-
@unpack fullvars = s
249-
var_eq_matching, var_sccs = algebraic_variables_scc(sys)
248+
state = get_or_construct_tearing_state(sys)
249+
fullvars = state.fullvars
250+
var_eq_matching, var_sccs = algebraic_variables_scc(state)
250251
condensed_graph = MatchedCondensationGraph(
251-
DiCMOBiGraph{true}(complete(s.graph), complete(var_eq_matching)), var_sccs)
252+
DiCMOBiGraph{true}(complete(state.structure.graph), complete(var_eq_matching)), var_sccs)
252253
toporder = topological_sort_by_dfs(condensed_graph)
253254
var_sccs = var_sccs[toporder]
254255

255-
states_idxs = collect(diffvars_range(s))
256+
states_idxs = collect(diffvars_range(state.structure))
256257
mass_matrix_diag = ones(length(states_idxs))
257258

258259
assignments, deps, sol_states = tearing_assignments(sys)
@@ -276,7 +277,7 @@ function build_torn_function(
276277
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
277278
isempty(torn_eqs_idxs) && continue
278279
if length(torn_eqs_idxs) <= max_inlining_size
279-
nlsolve_expr = gen_nlsolve!(is_not_prepended_assignment, eqs[torn_eqs_idxs], s.fullvars[torn_vars_idxs], defs, assignments, (deps, invdeps), var2assignment, checkbounds=checkbounds)
280+
nlsolve_expr = gen_nlsolve!(is_not_prepended_assignment, eqs[torn_eqs_idxs], fullvars[torn_vars_idxs], defs, assignments, (deps, invdeps), var2assignment, checkbounds=checkbounds)
280281
append!(torn_expr, nlsolve_expr)
281282
push!(nlsolve_scc_idxs, i)
282283
else
@@ -297,7 +298,7 @@ function build_torn_function(
297298
rhss
298299
)
299300

300-
states = s.fullvars[states_idxs]
301+
states = fullvars[states_idxs]
301302
syms = map(Symbol, states_idxs)
302303

303304
pre = get_postprocess_fbody(sys)
@@ -322,10 +323,10 @@ function build_torn_function(
322323
if expression
323324
expr, states
324325
else
325-
observedfun = let sys=sys, dict=Dict(), assignments=assignments, deps=(deps, invdeps), sol_states=sol_states, var2assignment=var2assignment
326+
observedfun = let state = state, dict=Dict(), assignments=assignments, deps=(deps, invdeps), sol_states=sol_states, var2assignment=var2assignment
326327
function generated_observed(obsvar, u, p, t)
327328
obs = get!(dict, value(obsvar)) do
328-
build_observed_function(sys, obsvar, var_eq_matching, var_sccs,
329+
build_observed_function(state, obsvar, var_eq_matching, var_sccs,
329330
assignments, deps, sol_states, var2assignment,
330331
checkbounds=checkbounds,
331332
)
@@ -336,7 +337,7 @@ function build_torn_function(
336337

337338
ODEFunction{true}(
338339
@RuntimeGeneratedFunction(expr),
339-
sparsity = jacobian_sparsity ? torn_system_with_nlsolve_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing,
340+
sparsity = jacobian_sparsity ? torn_system_with_nlsolve_jacobian_sparsity(state, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing,
340341
syms = syms,
341342
observed = observedfun,
342343
mass_matrix = mass_matrix,
@@ -362,7 +363,7 @@ function find_solve_sequence(sccs, vars)
362363
end
363364

364365
function build_observed_function(
365-
sys, ts, var_eq_matching, var_sccs,
366+
state, ts, var_eq_matching, var_sccs,
366367
assignments,
367368
deps,
368369
sol_states,
@@ -379,12 +380,14 @@ function build_observed_function(
379380
ts = Symbolics.scalarize.(value.(ts))
380381

381382
vars = Set()
383+
sys = state.sys
382384
foreach(Base.Fix1(vars!, vars), ts)
383385
ivs = independent_variables(sys)
384386
dep_vars = collect(setdiff(vars, ivs))
385387

386-
s = structure(sys)
387-
@unpack fullvars, graph = s
388+
fullvars = state.fullvars
389+
s = state.structure
390+
graph = s.graph
388391
diffvars = map(i->fullvars[i], diffvars_range(s))
389392
algvars = map(i->fullvars[i], algvars_range(s))
390393

@@ -416,8 +419,13 @@ function build_observed_function(
416419
if !isempty(subset)
417420
eqs = equations(sys)
418421

419-
torn_eqs = map(i->map(v->eqs[var_eq_matching[v]], var_sccs[i]), subset)
420-
torn_vars = map(i->map(v->fullvars[v], var_sccs[i]), subset)
422+
nested_torn_vars_idxs = []
423+
for iscc in subset
424+
torn_vars_idxs = Int[var for var in var_sccs[iscc] if var_eq_matching[var] !== unassigned]
425+
isempty(torn_vars_idxs) || push!(nested_torn_vars_idxs, torn_vars_idxs)
426+
end
427+
torn_eqs = [[eqs[var_eq_matching[i]] for i in idxs] for idxs in nested_torn_vars_idxs]
428+
torn_vars = [fullvars[idxs] for idxs in nested_torn_vars_idxs]
421429
u0map = defaults(sys)
422430
assignments = copy(assignments)
423431
solves = map(zip(torn_eqs, torn_vars)) do (eqs, vars)

0 commit comments

Comments
 (0)