Skip to content

Commit 85df648

Browse files
authored
Merge pull request #2192 from Keno/kf/revivepss
Revive `pss_graph_modia!`
2 parents 20b9203 + 5c0268b commit 85df648

File tree

2 files changed

+46
-28
lines changed

2 files changed

+46
-28
lines changed

src/structural_transformation/partial_state_selection.jl

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,36 +27,40 @@ function ascend_dg_all(xs, dg, level, maxlevel)
2727
return r
2828
end
2929

30-
function pss_graph_modia!(structure::SystemStructure, var_eq_matching, varlevel,
30+
function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, varlevel,
3131
inv_varlevel, inv_eqlevel)
3232
@unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure
3333

3434
# var_eq_matching is a maximal matching on the top-differentiated variables.
3535
# Find Strongly connected components. Note that after pantelides, we expect
3636
# a balanced system, so a maximal matching should be possible.
37-
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)
38-
var_eq_matching = Matching{Union{Unassigned, SelectedState}}(var_eq_matching)
37+
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, maximal_top_matching)
38+
var_eq_matching = Matching{Union{Unassigned, SelectedState}}(ndsts(graph))
3939
for vars in var_sccs
4040
# TODO: We should have a way to not have the scc code look at unassigned vars.
41-
if length(vars) == 1 && varlevel[vars[1]] != 0
41+
if length(vars) == 1 && maximal_top_matching[vars[1]] === unassigned
4242
continue
4343
end
4444

4545
# Now proceed level by level from lowest to highest and tear the graph.
46-
eqs = [var_eq_matching[var] for var in vars if var_eq_matching[var] !== unassigned]
46+
eqs = [maximal_top_matching[var]
47+
for var in vars if maximal_top_matching[var] !== unassigned]
4748
isempty(eqs) && continue
48-
maxlevel = level = maximum(map(x -> inv_eqlevel[x], eqs))
49+
maxeqlevel = maximum(map(x -> inv_eqlevel[x], eqs))
50+
maxvarlevel = level = maximum(map(x -> inv_varlevel[x], vars))
4951
old_level_vars = ()
5052
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph,
5153
complete(Matching(ndsts(graph))));
5254
dir = :in)
55+
5356
while level >= 0
5457
to_tear_eqs_toplevel = filter(eq -> inv_eqlevel[eq] >= level, eqs)
5558
to_tear_eqs = ascend_dg(to_tear_eqs_toplevel, invview(eq_to_diff), level)
5659

5760
to_tear_vars_toplevel = filter(var -> inv_varlevel[var] >= level, vars)
58-
to_tear_vars = ascend_dg_all(to_tear_vars_toplevel, invview(var_to_diff), level,
59-
maxlevel)
61+
to_tear_vars = ascend_dg(to_tear_vars_toplevel, invview(var_to_diff), level)
62+
63+
assigned_eqs = Int[]
6064

6165
if old_level_vars !== ()
6266
# Inherit constraints from previous level.
@@ -66,45 +70,59 @@ function pss_graph_modia!(structure::SystemStructure, var_eq_matching, varlevel,
6670
removed_eqs = Int[]
6771
removed_vars = Int[]
6872
for var in old_level_vars
69-
old_assign = ict.graph.matching[var]
70-
if !isa(old_assign, Int) ||
71-
ict.graph.matching[var_to_diff[var]] !== unassigned
73+
old_assign = var_eq_matching[var]
74+
if isa(old_assign, SelectedState)
75+
push!(removed_vars, var)
76+
continue
77+
elseif !isa(old_assign, Int) ||
78+
ict.graph.matching[var_to_diff[var]] !== unassigned
7279
continue
7380
end
7481
# Make sure the ict knows about this edge, so it doesn't accidentally introduce
7582
# a cycle.
76-
ok = try_assign_eq!(ict, var_to_diff[var], eq_to_diff[old_assign])
83+
assgned_eq = eq_to_diff[old_assign]
84+
ok = try_assign_eq!(ict, var_to_diff[var], assgned_eq)
7785
@assert ok
78-
var_eq_matching[var_to_diff[var]] = eq_to_diff[old_assign]
86+
var_eq_matching[var_to_diff[var]] = assgned_eq
7987
push!(removed_eqs, eq_to_diff[ict.graph.matching[var]])
8088
push!(removed_vars, var_to_diff[var])
89+
push!(removed_vars, var)
8190
end
8291
to_tear_eqs = setdiff(to_tear_eqs, removed_eqs)
8392
to_tear_vars = setdiff(to_tear_vars, removed_vars)
8493
end
85-
filter!(var -> ict.graph.matching[var] === unassigned, to_tear_vars)
86-
filter!(eq -> invview(ict.graph.matching)[eq] === unassigned, to_tear_eqs)
8794
tearEquations!(ict, solvable_graph.fadjlist, to_tear_eqs, BitSet(to_tear_vars),
8895
nothing)
96+
8997
for var in to_tear_vars
90-
var_eq_matching[var] = unassigned
98+
@assert var_eq_matching[var] === unassigned
99+
assgned_eq = ict.graph.matching[var]
100+
var_eq_matching[var] = assgned_eq
101+
isa(assgned_eq, Int) && push!(assigned_eqs, assgned_eq)
91102
end
92-
for var in to_tear_vars
93-
var_eq_matching[var] = ict.graph.matching[var]
103+
104+
if level != 0
105+
remaining_vars = collect(v for v in to_tear_vars
106+
if var_eq_matching[v] === unassigned)
107+
if !isempty(remaining_vars)
108+
remaining_eqs = setdiff(to_tear_eqs, assigned_eqs)
109+
nlsolve_matching = maximal_matching(graph,
110+
Base.Fix2(in, remaining_eqs),
111+
Base.Fix2(in, remaining_vars))
112+
for var in remaining_vars
113+
if nlsolve_matching[var] === unassigned &&
114+
var_eq_matching[var] === unassigned
115+
var_eq_matching[var] = SelectedState()
116+
end
117+
end
118+
end
94119
end
120+
95121
old_level_vars = to_tear_vars
96122
level -= 1
97123
end
98124
end
99-
for var in 1:ndsts(graph)
100-
dv = var_to_diff[var]
101-
# If `var` is not algebraic (not differentiated nor a dummy derivative),
102-
# then it's a SelectedState
103-
if !(dv === nothing || (varlevel[dv] !== 0 && var_eq_matching[dv] === unassigned))
104-
var_eq_matching[var] = SelectedState()
105-
end
106-
end
107-
return var_eq_matching
125+
return complete(var_eq_matching)
108126
end
109127

110128
struct SelectedState end

test/structural_transformation/index_reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ sol = solve(prob_auto, Rodas5());
118118
#plot(sol, idxs=(D(x), y))
119119

120120
let pss_pendulum2 = partial_state_selection(pendulum2)
121-
@test_broken length(equations(pss_pendulum2)) <= 6
121+
@test length(equations(pss_pendulum2)) <= 6
122122
end
123123

124124
eqs = [D(x) ~ w,

0 commit comments

Comments
 (0)