Skip to content

Commit 57dffe6

Browse files
KenoYingboMa
andcommitted
WIP: Partial State Selection
This implements the basic outline of partial state selection. It is functional and resolves the example in #988, but we do not currently implement any tearing priorities or handling of linear subsystems, so it's pretty easy for state selection to end up selecting states that are singular. Co-authored-by: Yingbo Ma <[email protected]>
1 parent 09d2ed3 commit 57dffe6

File tree

7 files changed

+193
-45
lines changed

7 files changed

+193
-45
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ using SparseArrays
3636

3737
using NonlinearSolve
3838

39-
export tearing, dae_index_lowering, check_consistency
39+
export tearing, partial_state_selection, dae_index_lowering, check_consistency
4040
export build_torn_function, build_observed_function, ODAEProblem
4141
export sorted_incidence_matrix
4242

@@ -45,6 +45,7 @@ include("pantelides.jl")
4545
include("bipartite_tearing/modia_tearing.jl")
4646
include("tearing.jl")
4747
include("symbolics_tearing.jl")
48+
include("partial_state_selection.jl")
4849
include("codegen.jl")
4950

5051
end # module

src/structural_transformation/pantelides.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,14 @@ end
7272
Perform Pantelides algorithm.
7373
"""
7474
function pantelides!(sys::ODESystem; maxiters = 8000)
75+
find_solvables!(sys)
7576
s = structure(sys)
7677
# D(j) = assoc[j]
77-
@unpack graph, var_to_diff = s
78-
return (sys, pantelides!(graph, var_to_diff)...)
78+
@unpack graph, var_to_diff, solvable_graph = s
79+
return (sys, pantelides!(graph, solvable_graph, var_to_diff)...)
7980
end
8081

81-
function pantelides!(graph, var_to_diff; maxiters = 8000)
82+
function pantelides!(graph, solvable_graph, var_to_diff; maxiters = 8000)
8283
neqs = nsrcs(graph)
8384
nvars = nv(var_to_diff)
8485
vcolor = falses(nvars)
@@ -106,7 +107,7 @@ function pantelides!(graph, var_to_diff; maxiters = 8000)
106107
for var in eachindex(vcolor); vcolor[var] || continue
107108
# introduce a new variable
108109
nvars += 1
109-
add_vertex!(graph, DST)
110+
add_vertex!(graph, DST); add_vertex!(solvable_graph, DST)
110111
# the new variable is the derivative of `var`
111112

112113
add_edge!(var_to_diff, var, add_vertex!(var_to_diff))
@@ -116,13 +117,16 @@ function pantelides!(graph, var_to_diff; maxiters = 8000)
116117
for eq in eachindex(ecolor); ecolor[eq] || continue
117118
# introduce a new equation
118119
neqs += 1
119-
add_vertex!(graph, SRC)
120+
add_vertex!(graph, SRC); add_vertex!(solvable_graph, SRC)
120121
# the new equation is created by differentiating `eq`
121122
eq_diff = add_vertex!(eq_to_diff)
122123
add_edge!(eq_to_diff, eq, eq_diff)
123124
for var in 𝑠neighbors(graph, eq)
124125
add_edge!(graph, eq_diff, var)
125126
add_edge!(graph, eq_diff, var_to_diff[var])
127+
# If you have f(x) = 0, then the derivative is (∂f/∂x) ẋ = 0.
128+
# which is linear, thus solvable in ẋ.
129+
add_edge!(solvable_graph, eq_diff, var_to_diff[var])
126130
end
127131
end
128132

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
function partial_state_selection_graph!(sys::ODESystem)
2+
s = get_structure(sys)
3+
(s isa SystemStructure) || (sys = initialize_system_structure(sys))
4+
s = structure(sys)
5+
find_solvables!(sys)
6+
@set! s.graph = complete(s.graph)
7+
@set! sys.structure = s
8+
(sys, partial_state_selection_graph!(s.graph, s.solvable_graph, s.var_to_diff)...)
9+
end
10+
11+
struct SelectedState; end
12+
function partial_state_selection_graph!(graph, solvable_graph, var_to_diff)
13+
var_eq_matching, eq_to_diff = pantelides!(graph, solvable_graph, var_to_diff)
14+
eq_to_diff = complete(eq_to_diff)
15+
16+
eqlevel = map(1:nsrcs(graph)) do eq
17+
level = 0
18+
while eq_to_diff[eq] !== nothing
19+
eq = eq_to_diff[eq]
20+
level += 1
21+
end
22+
level
23+
end
24+
25+
varlevel = map(1:ndsts(graph)) do var
26+
level = 0
27+
while var_to_diff[var] !== nothing
28+
var = var_to_diff[var]
29+
level += 1
30+
end
31+
level
32+
end
33+
34+
all_selected_states = Int[]
35+
36+
level = 0
37+
level_vars = [var for var in 1:ndsts(graph) if varlevel[var] == 0 && invview(var_to_diff)[var] !== nothing]
38+
39+
# TODO: Is this actually useful or should we just compute another maximal matching?
40+
for var in 1:ndsts(graph)
41+
if !(var in level_vars)
42+
var_eq_matching[var] = unassigned
43+
end
44+
end
45+
46+
while level < maximum(eqlevel)
47+
var_eq_matching = tear_graph_modia(graph, solvable_graph;
48+
eqfilter = eq->eqlevel[eq] == level && invview(eq_to_diff)[eq] !== nothing,
49+
varfilter = var->(var in level_vars && !(var in all_selected_states)))
50+
for var in level_vars
51+
if var_eq_matching[var] === unassigned
52+
selected_state = invview(var_to_diff)[var]
53+
push!(all_selected_states, selected_state)
54+
#=
55+
# TODO: This is what the Matteson paper says, but it doesn't
56+
# quite seem to work.
57+
while selected_state !== nothing
58+
push!(all_selected_states, selected_state)
59+
selected_state = invview(var_to_diff)[selected_state]
60+
end
61+
=#
62+
end
63+
end
64+
level += 1
65+
level_vars = [var for var = 1:ndsts(graph) if varlevel[var] == level && invview(var_to_diff)[var] !== nothing]
66+
end
67+
68+
var_eq_matching = tear_graph_modia(graph, solvable_graph;
69+
varfilter = var->!(var in all_selected_states))
70+
var_eq_matching = Matching{Union{Unassigned, SelectedState}}(var_eq_matching)
71+
for var in all_selected_states
72+
var_eq_matching[var] = SelectedState()
73+
end
74+
return var_eq_matching, eq_to_diff
75+
end

src/structural_transformation/symbolics_tearing.jl

Lines changed: 88 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,49 @@ function tearing_sub(expr, dict, s)
33
s ? simplify(expr) : expr
44
end
55

6-
function tearing_reassemble(sys, var_eq_matching; simplify=false)
6+
function tearing_reassemble(sys, var_eq_matching, eq_to_diff=nothing; simplify=false)
77
s = structure(sys)
8-
@unpack fullvars, solvable_graph, graph = s
8+
@unpack fullvars, solvable_graph, var_to_diff, graph = s
99

1010
eqs = equations(sys)
1111

12+
### Add the differentiated equations and variables
13+
D = Differential(get_iv(sys))
14+
if length(fullvars) != length(var_to_diff)
15+
for i = (length(fullvars)+1):length(var_to_diff)
16+
push!(fullvars, D(fullvars[invview(var_to_diff)[i]]))
17+
end
18+
end
19+
20+
### Add the differentiated equations
21+
neweqs = copy(eqs)
22+
if eq_to_diff !== nothing
23+
eq_to_diff = complete(eq_to_diff)
24+
for i = (length(eqs)+1):length(eq_to_diff)
25+
eq = neweqs[invview(eq_to_diff)[i]]
26+
push!(neweqs, ModelingToolkit.expand_derivatives(0 ~ D(eq.rhs - eq.lhs)))
27+
end
28+
29+
### Replace derivatives of non-selected states by dumy derivatives
30+
dummy_subs = Dict()
31+
for var = 1:length(fullvars)
32+
invview(var_to_diff)[var] === nothing && continue
33+
if var_eq_matching[invview(var_to_diff)[var]] !== SelectedState()
34+
fullvar = fullvars[var]
35+
subst_fullvar = tearing_sub(fullvar, dummy_subs, simplify)
36+
dummy_subs[fullvar] = fullvars[var] = diff2term(unwrap(subst_fullvar))
37+
var_to_diff[invview(var_to_diff)[var]] = nothing
38+
end
39+
end
40+
neweqs = map(neweqs) do eq
41+
0 ~ tearing_sub(eq.rhs - eq.lhs, dummy_subs, simplify)
42+
end
43+
end
44+
1245
### extract partition information
1346
function solve_equation(ieq, iv)
1447
var = fullvars[iv]
15-
eq = eqs[ieq]
48+
eq = neweqs[ieq]
1649
rhs = value(solve_for(eq, var; simplify=simplify, check=false))
1750

1851
if var in vars(rhs)
@@ -30,7 +63,7 @@ function tearing_reassemble(sys, var_eq_matching; simplify=false)
3063
end
3164
var => rhs
3265
end
33-
is_solvable(eq, iv) = eq !== unassigned && BipartiteEdge(eq, iv) in solvable_graph
66+
is_solvable(eq, iv) = isa(eq, Int) && BipartiteEdge(eq, iv) in solvable_graph
3467

3568
solved_equations = Int[]
3669
solved_variables = Int[]
@@ -41,32 +74,31 @@ function tearing_reassemble(sys, var_eq_matching; simplify=false)
4174
push!(solved_equations, ieq); push!(solved_variables, iv)
4275
end
4376

77+
isdiffvar(var) = invview(var_to_diff)[var] !== nothing && var_eq_matching[invview(var_to_diff)[var]] === SelectedState()
4478
solved = Dict(solve_equation(ieq, iv) for (ieq, iv) in zip(solved_equations, solved_variables))
4579
obseqs = [var ~ rhs for (var, rhs) in solved]
4680

4781
# Rewrite remaining equations in terms of solved variables
4882
function substitute_equation(ieq)
49-
eq = eqs[ieq]
50-
if isdiffeq(eq)
51-
return eq.lhs ~ tearing_sub(eq.rhs, solved, simplify)
52-
else
53-
if !(eq.lhs isa Number && eq.lhs == 0)
54-
eq = 0 ~ eq.rhs - eq.lhs
55-
end
56-
rhs = tearing_sub(eq.rhs, solved, simplify)
57-
if rhs isa Symbolic
58-
return 0 ~ rhs
59-
else # a number
60-
if abs(rhs) > 100eps(float(rhs))
61-
@warn "The equation $eq is not consistent. It simplifed to 0 == $rhs."
62-
end
63-
return nothing
83+
eq = neweqs[ieq]
84+
if !(eq.lhs isa Number && eq.lhs == 0)
85+
eq = 0 ~ eq.rhs - eq.lhs
86+
end
87+
rhs = tearing_sub(eq.rhs, solved, simplify)
88+
if rhs isa Symbolic
89+
return 0 ~ rhs
90+
else # a number
91+
if abs(rhs) > 100eps(float(rhs))
92+
@warn "The equation $eq is not consistent. It simplifed to 0 == $rhs."
6493
end
94+
return nothing
6595
end
6696
end
6797

68-
neweqs = Any[substitute_equation(ieq) for ieq in 1:length(eqs) if !(ieq in solved_equations)]
98+
diffeqs = [fullvars[iv] ~ tearing_sub(solved[fullvars[iv]], solved, simplify) for iv in solved_variables if isdiffvar(iv)]
99+
neweqs = Any[substitute_equation(ieq) for ieq in 1:length(neweqs) if !(ieq in solved_equations)]
69100
filter!(!isnothing, neweqs)
101+
prepend!(neweqs, diffeqs)
70102

71103
# Contract the vertices in the structure graph to make the structure match
72104
# the new reality of the system we've just created.
@@ -78,30 +110,17 @@ function tearing_reassemble(sys, var_eq_matching; simplify=false)
78110

79111
@set! s.graph = graph
80112
@set! s.fullvars = [v for (i, v) in enumerate(fullvars) if i in active_vars]
113+
@set! s.var_to_diff = DiffGraph(Union{Int, Nothing}[v for (i, v) in enumerate(s.var_to_diff) if i in active_vars])
81114
@set! s.vartype = [v for (i, v) in enumerate(s.vartype) if i in active_vars]
82115
@set! s.algeqs = [e for (i, e) in enumerate(s.algeqs) if i in active_eqs]
83116

84117
@set! sys.structure = s
85118
@set! sys.eqs = neweqs
86-
@set! sys.states = [s.fullvars[idx] for idx in 1:length(s.fullvars) if !isdervar(s, idx)]
119+
@set! sys.states = [fullvars[i] for i in active_vars]
87120
@set! sys.observed = [observed(sys); obseqs]
88121
return sys
89122
end
90123

91-
"""
92-
tearing(sys; simplify=false)
93-
94-
Tear the nonlinear equations in system. When `simplify=true`, we simplify the
95-
new residual residual equations after tearing. End users are encouraged to call [`structural_simplify`](@ref)
96-
instead, which calls this function internally.
97-
"""
98-
function tearing(sys; simplify=false)
99-
sys = init_for_tearing(sys)
100-
var_eq_matching = tear_graph(sys)
101-
102-
tearing_reassemble(sys, var_eq_matching; simplify=simplify)
103-
end
104-
105124
function init_for_tearing(sys)
106125
s = get_structure(sys)
107126
if !(s isa SystemStructure)
@@ -119,6 +138,38 @@ end
119138
function tear_graph(sys)
120139
s = structure(sys)
121140
@unpack graph, solvable_graph = s
122-
tear_graph_modia(graph, solvable_graph;
123-
varfilter=var->isalgvar(s, var), eqfilter=eq->s.algeqs[eq])
141+
var_eq_matching = Matching{Union{Unassigned, SelectedState}}(tear_graph_modia(graph, solvable_graph;
142+
varfilter=var->isalgvar(s, var), eqfilter=eq->s.algeqs[eq]))
143+
for var in 1:ndsts(graph)
144+
if !isalgvar(s, var)
145+
var_eq_matching[var] = SelectedState()
146+
end
147+
end
148+
var_eq_matching
149+
end
150+
151+
"""
152+
tearing(sys; simplify=false)
153+
154+
Tear the nonlinear equations in system. When `simplify=true`, we simplify the
155+
new residual residual equations after tearing. End users are encouraged to call [`structural_simplify`](@ref)
156+
instead, which calls this function internally.
157+
"""
158+
function tearing(sys; simplify=false)
159+
sys = init_for_tearing(sys)
160+
var_eq_matching = tear_graph(sys)
161+
162+
tearing_reassemble(sys, var_eq_matching; simplify=simplify)
163+
end
164+
165+
"""
166+
tearing(sys; simplify=false)
167+
168+
Perform partial state selection and tearing.
169+
"""
170+
function partial_state_selection(sys; simplify=false)
171+
sys = init_for_tearing(sys)
172+
sys, var_eq_matching, eq_to_diff = partial_state_selection_graph!(sys)
173+
174+
tearing_reassemble(sys, var_eq_matching, eq_to_diff; simplify=simplify)
124175
end

src/structural_transformation/utils.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,8 @@ function find_solvables!(sys)
164164
eqs = equations(sys)
165165
empty!(solvable_graph)
166166
for (i, eq) in enumerate(eqs)
167-
isdiffeq(eq) && continue
168167
term = value(eq.rhs - eq.lhs)
169168
for j in 𝑠neighbors(graph, i)
170-
isalgvar(s, j) || continue
171169
var = fullvars[j]
172170
isinput(var) && continue
173171
a, b, islinear = linear_expansion(term, var)

src/systems/systemstructure.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ struct DiffGraph <: Graphs.AbstractGraph{Int}
6262
primal_to_diff::Vector{Union{Int, Nothing}}
6363
diff_to_primal::Union{Nothing, Vector{Union{Int, Nothing}}}
6464
end
65+
DiffGraph(primal_to_diff::Vector{Union{Int, Nothing}}) =
66+
DiffGraph(primal_to_diff, nothing)
6567
DiffGraph(n::Integer, with_badj::Bool=false) = DiffGraph(Union{Int, Nothing}[nothing for _=1:n],
6668
with_badj ? Union{Int, Nothing}[nothing for _=1:n] : nothing)
6769

test/structural_transformation/index_reduction.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,20 @@ p = [
134134
prob_auto = ODEProblem(new_sys,u0,(0.0,10.0),p)
135135
sol = solve(prob_auto, Rodas5());
136136
#plot(sol, vars=(D(x), y))
137+
138+
let pss_pendulum2 = partial_state_selection(pendulum2)
139+
@test length(equations(pss_pendulum2)) == 3
140+
@test length(equations(ModelingToolkit.ode_order_lowering(pss_pendulum2))) == 4
141+
end
142+
143+
eqs = [D(x) ~ w,
144+
D(y) ~ z,
145+
D(w) ~ T*x,
146+
D(z) ~ T*y - g,
147+
0 ~ x^2 + y^2 - L^2]
148+
pendulum = ODESystem(eqs, t, [x, y, w, z, T], [L, g], name=:pendulum)
149+
150+
let pss_pendulum = partial_state_selection(pendulum)
151+
@test length(equations(pss_pendulum)) == 3
152+
@test length(equations(ModelingToolkit.ode_order_lowering(pss_pendulum))) == 4
153+
end

0 commit comments

Comments
 (0)