Skip to content

Commit 8ed38eb

Browse files
committed
Post State-Selection cleanup
1 parent e13068d commit 8ed38eb

20 files changed

+294
-309
lines changed

docs/src/tutorials/tearing_parallelism.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,8 @@ investigate what this means:
211211

212212
```julia
213213
using ModelingToolkit.BipartiteGraphs
214-
big_rc = initialize_system_structure(big_rc)
215-
inc_org = BipartiteGraphs.incidence_matrix(structure(big_rc).graph)
214+
ts = TearingState(big_rc)
215+
inc_org = BipartiteGraphs.incidence_matrix(ts.graph)
216216
blt_org = StructuralTransformations.sorted_incidence_matrix(big_rc, only_algeqs=true, only_algvars=true)
217217
blt_reduced = StructuralTransformations.sorted_incidence_matrix(sys, only_algeqs=true, only_algvars=true)
218218
```

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ export calculate_factorized_W, generate_factorized_W
193193
export calculate_hessian, generate_hessian
194194
export calculate_massmatrix, generate_diffusion_function
195195
export stochastic_integral_transform
196-
export initialize_system_structure
196+
export TearingState, StateSelectionState
197197
export generate_difference_cb
198198

199199
export BipartiteGraph, equation_dependencies, variable_dependencies

src/structural_transformation/StructuralTransformations.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ using SymbolicUtils.Rewriters
1010
using SymbolicUtils: similarterm, istree
1111

1212
using ModelingToolkit
13-
using ModelingToolkit: ODESystem, AbstractSystem,var_from_nested_derivative, Differential,
13+
using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Differential,
1414
states, equations, vars, Symbolic, diff2term, value,
1515
operation, arguments, Sym, Term, simplify, solve_for,
1616
isdiffeq, isdifferential, isinput,
17-
get_structure, get_iv, independent_variables,
18-
get_structure, defaults, InvalidSystemException,
17+
get_iv, independent_variables,
18+
defaults, InvalidSystemException,
1919
ExtraEquationsSystemException,
2020
ExtraVariablesSystemException,
2121
get_postprocess_fbody, vars!,
@@ -25,6 +25,7 @@ using ModelingToolkit.BipartiteGraphs
2525
import .BipartiteGraphs: invview
2626
using Graphs
2727
using ModelingToolkit.SystemStructures
28+
using ModelingToolkit.SystemStructures: algeqs
2829

2930
using ModelingToolkit.DiffEqBase
3031
using ModelingToolkit.StaticArrays
@@ -38,7 +39,7 @@ using NonlinearSolve
3839

3940
export tearing, partial_state_selection, dae_index_lowering, check_consistency
4041
export build_torn_function, build_observed_function, ODAEProblem
41-
export sorted_incidence_matrix
42+
export sorted_incidence_matrix, pantelides!, tearing_reassemble
4243

4344
include("utils.jl")
4445
include("pantelides.jl")

src/structural_transformation/codegen.jl

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ using LinearAlgebra
22

33
const MAX_INLINE_NLSOLVE_SIZE = 8
44

5-
function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs)
6-
s = structure(sys)
7-
@unpack fullvars, graph = s
5+
function torn_system_jacobian_sparsity(state, var_eq_matching, var_sccs)
6+
fullvars = state.fullvars
7+
graph = state.structure.graph
88

99
# The sparsity pattern of `nlsolve(f, u, p)` w.r.t `p` is difficult to
1010
# determine in general. Consider the "simplest" case, a linear system. We
@@ -59,12 +59,12 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs)
5959
Graphs.insorted(var, v_residual) && continue
6060
deps = get(avars2dvars, var, nothing)
6161
if deps === nothing # differential variable
62-
@assert !isalgvar(s, var)
62+
@assert !isalgvar(state.structure, var)
6363
for tvar in v_residual
6464
push!(avars2dvars[tvar], var)
6565
end
6666
else # tearing variable from previous partitions
67-
@assert isalgvar(s, var)
67+
@assert isalgvar(state.structure, var)
6868
for tvar in v_residual
6969
union!(avars2dvars[tvar], avars2dvars[var])
7070
end
@@ -73,18 +73,19 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs)
7373
end
7474
end
7575

76-
dvrange = diffvars_range(s)
76+
dvrange = diffvars_range(state.structure)
7777
dvar2idx = Dict(v=>i for (i, v) in enumerate(dvrange))
7878
I = Int[]; J = Int[]
7979
eqidx = 0
80+
aeqs = algeqs(state.structure)
8081
for ieq in 𝑠vertices(graph)
81-
isalgeq(s, ieq) && continue
82+
ieq in aeqs && continue
8283
eqidx += 1
8384
for ivar in 𝑠neighbors(graph, ieq)
84-
if isdiffvar(s, ivar)
85+
if isdiffvar(state.structure, ivar)
8586
push!(I, eqidx)
8687
push!(J, dvar2idx[ivar])
87-
elseif isalgvar(s, ivar)
88+
elseif isalgvar(state.structure, ivar)
8889
for dvar in avars2dvars[ivar]
8990
push!(I, eqidx)
9091
push!(J, dvar2idx[dvar])
@@ -170,18 +171,18 @@ function build_torn_function(
170171
isdiffeq(eq) && push!(rhss, eq.rhs)
171172
end
172173

173-
s = structure(sys)
174-
@unpack fullvars = s
175-
var_eq_matching, var_sccs = algebraic_variables_scc(sys)
174+
state = TearingState(sys)
175+
fullvars = state.fullvars
176+
var_eq_matching, var_sccs = algebraic_variables_scc(state)
176177

177-
states = map(i->s.fullvars[i], diffvars_range(s))
178+
states = map(i->fullvars[i], diffvars_range(state.structure))
178179
mass_matrix_diag = ones(length(states))
179180
torn_expr = []
180181
defs = defaults(sys)
181182

182183
needs_extending = false
183184
for scc in var_sccs
184-
torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
185+
torn_vars = [fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
185186
torn_eqs = [eqs[var_eq_matching[var]] for var in scc if var_eq_matching[var] !== unassigned]
186187
isempty(torn_eqs) && continue
187188
if length(torn_eqs) <= max_inlining_size
@@ -224,18 +225,18 @@ function build_torn_function(
224225
if expression
225226
expr, states
226227
else
227-
observedfun = let sys = sys, dict = Dict()
228+
observedfun = let state = state, dict = Dict()
228229
function generated_observed(obsvar, u, p, t)
229230
obs = get!(dict, value(obsvar)) do
230-
build_observed_function(sys, obsvar, var_eq_matching, var_sccs, checkbounds=checkbounds)
231+
build_observed_function(state, obsvar, var_eq_matching, var_sccs, checkbounds=checkbounds)
231232
end
232233
obs(u, p, t)
233234
end
234235
end
235236

236237
ODEFunction{true}(
237238
@RuntimeGeneratedFunction(expr),
238-
sparsity = torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs),
239+
sparsity = torn_system_jacobian_sparsity(state, var_eq_matching, var_sccs),
239240
syms = syms,
240241
observed = observedfun,
241242
mass_matrix = mass_matrix,
@@ -261,7 +262,7 @@ function find_solve_sequence(sccs, vars)
261262
end
262263

263264
function build_observed_function(
264-
sys, ts, var_eq_matching, var_sccs;
265+
state, ts, var_eq_matching, var_sccs;
265266
expression=false,
266267
output_type=Array,
267268
checkbounds=true
@@ -273,12 +274,14 @@ function build_observed_function(
273274
ts = Symbolics.scalarize.(value.(ts))
274275

275276
vars = Set()
277+
sys = state.sys
276278
foreach(Base.Fix1(vars!, vars), ts)
277279
ivs = independent_variables(sys)
278280
dep_vars = collect(setdiff(vars, ivs))
279281

280-
s = structure(sys)
281-
@unpack fullvars, graph = s
282+
fullvars = state.fullvars
283+
s = state.structure
284+
graph = s.graph
282285
diffvars = map(i->fullvars[i], diffvars_range(s))
283286
algvars = map(i->fullvars[i], algvars_range(s))
284287

src/structural_transformation/pantelides.jl

Lines changed: 17 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
### Reassemble: structural information -> system
33
###
44

5-
function pantelides_reassemble(sys::ODESystem, eq_to_diff, assign)
6-
s = structure(sys)
7-
@unpack fullvars, var_to_diff = s
5+
function pantelides_reassemble(state::TearingState, var_eq_matching)
6+
fullvars = state.fullvars
7+
@unpack var_to_diff, eq_to_diff = state.structure
8+
sys = state.sys
89
# Step 1: write derivative equations
910
in_eqs = equations(sys)
1011
out_eqs = Vector{Any}(undef, nv(eq_to_diff))
@@ -58,53 +59,29 @@ function pantelides_reassemble(sys::ODESystem, eq_to_diff, assign)
5859
end
5960

6061
final_vars = unique(filter(x->!(operation(x) isa Differential), fullvars))
61-
final_eqs = map(identity, filter(x->value(x.lhs) !== nothing, out_eqs[sort(filter(x->x !== unassigned, assign))]))
62+
final_eqs = map(identity, filter(x->value(x.lhs) !== nothing, out_eqs[sort(filter(x->x !== unassigned, var_eq_matching))]))
6263

6364
@set! sys.eqs = final_eqs
6465
@set! sys.states = final_vars
65-
@set! sys.structure = nothing
6666
return sys
6767
end
6868

6969
"""
70-
pantelides!(sys::ODESystem; kwargs...)
70+
pantelides!(state::TransformationState; kwargs...)
7171
7272
Perform Pantelides algorithm.
7373
"""
74-
function pantelides!(sys::ODESystem; maxiters = 8000)
75-
find_solvables!(sys)
76-
s = structure(sys)
77-
# D(j) = assoc[j]
78-
@unpack graph, var_to_diff = s
79-
# N.B.: var_derivative! and eq_derivative! are defined in symbolics_tearing.jl
80-
return (sys, pantelides!(PantelidesSetup(sys, graph, var_to_diff))...)
81-
end
82-
83-
struct PantelidesSetup{T}
84-
system::T
85-
graph::BipartiteGraph
86-
var_to_diff::DiffGraph
87-
eq_to_diff::DiffGraph
88-
var_eq_matching::Matching
89-
end
90-
91-
function PantelidesSetup(sys::T, graph, var_to_diff) where {T}
92-
neqs = nsrcs(graph)
93-
nvars = nv(var_to_diff)
94-
var_eq_matching = Matching(nvars)
95-
eq_to_diff = DiffGraph(neqs)
96-
PantelidesSetup{T}(sys, graph, var_to_diff, eq_to_diff, var_eq_matching)
97-
end
98-
99-
function pantelides!(p::PantelidesSetup; maxiters = 8000)
100-
@unpack graph, var_to_diff, eq_to_diff, var_eq_matching = p
74+
function pantelides!(state::TransformationState; maxiters = 8000)
75+
@unpack graph, var_to_diff, eq_to_diff = state.structure
10176
neqs = nsrcs(graph)
10277
nvars = nv(var_to_diff)
10378
vcolor = falses(nvars)
10479
ecolor = falses(neqs)
80+
var_eq_matching = Matching(nvars)
10581
neqs′ = neqs
10682
for k in 1:neqs′
10783
eq′ = k
84+
isempty(𝑠neighbors(graph, eq′)) && continue
10885
pathfound = false
10986
# In practice, `maxiters=8000` should never be reached, otherwise, the
11087
# index would be on the order of thousands.
@@ -128,7 +105,7 @@ function pantelides!(p::PantelidesSetup; maxiters = 8000)
128105

129106
add_edge!(var_to_diff, var, add_vertex!(var_to_diff))
130107
push!(var_eq_matching, unassigned)
131-
var_derivative!(p, eq)
108+
var_derivative!(state, var)
132109
end
133110

134111
for eq in eachindex(ecolor); ecolor[eq] || continue
@@ -138,7 +115,7 @@ function pantelides!(p::PantelidesSetup; maxiters = 8000)
138115
# the new equation is created by differentiating `eq`
139116
eq_diff = add_vertex!(eq_to_diff)
140117
add_edge!(eq_to_diff, eq, eq_diff)
141-
eq_derivative!(p, eq)
118+
eq_derivative!(state, eq)
142119
end
143120

144121
for var in eachindex(vcolor); vcolor[var] || continue
@@ -150,7 +127,7 @@ function pantelides!(p::PantelidesSetup; maxiters = 8000)
150127
end # for _ in 1:maxiters
151128
pathfound || error("maxiters=$maxiters reached! File a bug report if your system has a reasonable index (<100), and you are using the default `maxiters`. Try to increase the maxiters by `pantelides(sys::ODESystem; maxiters=1_000_000)` if your system has an incredibly high index and it is truly extremely large.")
152129
end # for k in 1:neqs′
153-
return var_eq_matching, eq_to_diff
130+
return var_eq_matching
154131
end
155132

156133
"""
@@ -161,8 +138,8 @@ DAE. `kwargs` are forwarded to [`pantelides!`](@ref). End users are encouraged t
161138
instead, which calls this function internally.
162139
"""
163140
function dae_index_lowering(sys::ODESystem; kwargs...)
164-
s = get_structure(sys)
165-
(s isa SystemStructure) || (sys = initialize_system_structure(sys))
166-
sys, var_eq_matching, eq_to_diff = pantelides!(sys; kwargs...)
167-
return pantelides_reassemble(sys, eq_to_diff, var_eq_matching)
141+
state = TearingState(sys)
142+
find_solvables!(state)
143+
var_eq_matching = pantelides!(state; kwargs...)
144+
return pantelides_reassemble(state, var_eq_matching)
168145
end

src/structural_transformation/partial_state_selection.jl

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
function partial_state_selection_graph!(sys::ODESystem)
2-
s = get_structure(sys)
3-
(s isa SystemStructure) || (sys = initialize_system_structure(sys))
4-
s = structure(sys)
5-
find_solvables!(sys; allow_symbolic=true)
6-
@set! s.graph = complete(s.graph)
7-
@set! sys.structure = s
8-
var_eq_matching, eq_to_diff = pantelides!(PantelidesSetup(sys, s.graph, s.var_to_diff))
9-
(sys, partial_state_selection_graph!(s.graph, s.solvable_graph, s.var_to_diff, var_eq_matching, eq_to_diff)...)
1+
function partial_state_selection_graph!(state::TransformationState)
2+
find_solvables!(state; allow_symbolic=true)
3+
var_eq_matching = complete(pantelides!(state))
4+
complete!(state.structure)
5+
partial_state_selection_graph!(state.structure, var_eq_matching)
106
end
117

128
function ascend_dg(xs, dg, level)
@@ -31,7 +27,9 @@ function ascend_dg_all(xs, dg, level, maxlevel)
3127
return r
3228
end
3329

34-
function pss_graph_modia!(graph, solvable_graph, var_eq_matching, var_to_diff, eq_to_diff, varlevel, inv_varlevel, inv_eqlevel)
30+
function pss_graph_modia!(structure::SystemStructure, var_eq_matching, varlevel, inv_varlevel, inv_eqlevel)
31+
@unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure
32+
3533
# var_eq_matching is a maximal matching on the top-differentiated variables.
3634
# Find Strongly connected components. Note that after pantelides, we expect
3735
# a balanced system, so a maximal matching should be possible.
@@ -44,7 +42,8 @@ function pss_graph_modia!(graph, solvable_graph, var_eq_matching, var_to_diff, e
4442
end
4543

4644
# Now proceed level by level from lowest to highest and tear the graph.
47-
eqs = [var_eq_matching[var] for var in vars]
45+
eqs = [var_eq_matching[var] for var in vars if var_eq_matching[var] !== unassigned]
46+
isempty(eqs) && continue
4847
maxlevel = level = maximum(map(x->inv_eqlevel[x], eqs))
4948
old_level_vars = ()
5049
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph, complete(Matching(ndsts(graph)))); dir=:in)
@@ -97,7 +96,8 @@ function pss_graph_modia!(graph, solvable_graph, var_eq_matching, var_to_diff, e
9796
end
9897

9998
struct SelectedState; end
100-
function partial_state_selection_graph!(graph, solvable_graph, var_to_diff, var_eq_matching, eq_to_diff)
99+
function partial_state_selection_graph!(structure::SystemStructure, var_eq_matching)
100+
@unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure
101101
eq_to_diff = complete(eq_to_diff)
102102

103103
inv_eqlevel = map(1:nsrcs(graph)) do eq
@@ -134,9 +134,8 @@ function partial_state_selection_graph!(graph, solvable_graph, var_to_diff, var_
134134
end
135135
end
136136

137-
var_eq_matching = pss_graph_modia!(graph, solvable_graph,
138-
complete(var_eq_matching), var_to_diff, eq_to_diff, varlevel, inv_varlevel,
139-
inv_eqlevel)
137+
var_eq_matching = pss_graph_modia!(structure,
138+
complete(var_eq_matching), varlevel, inv_varlevel, inv_eqlevel)
140139

141-
var_eq_matching, eq_to_diff
140+
var_eq_matching
142141
end

0 commit comments

Comments
 (0)