Skip to content

Commit ab25ef5

Browse files
refactor: remove partial_state_selection
1 parent 368d2f7 commit ab25ef5

File tree

4 files changed

+1
-202
lines changed

4 files changed

+1
-202
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ using SimpleNonlinearSolve
5555

5656
using DocStringExtensions
5757

58-
export tearing, partial_state_selection, dae_index_lowering, check_consistency
58+
export tearing, dae_index_lowering, check_consistency
5959
export dummy_derivative
6060
export sorted_incidence_matrix,
6161
pantelides!, pantelides_reassemble, tearing_reassemble, find_solvables!,

src/structural_transformation/partial_state_selection.jl

Lines changed: 0 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -1,173 +1,4 @@
1-
function partial_state_selection_graph!(state::TransformationState)
2-
find_solvables!(state; allow_symbolic = true)
3-
var_eq_matching = complete(pantelides!(state))
4-
complete!(state.structure)
5-
partial_state_selection_graph!(state.structure, var_eq_matching)
6-
end
7-
8-
function ascend_dg(xs, dg, level)
9-
while level > 0
10-
xs = Int[dg[x] for x in xs]
11-
level -= 1
12-
end
13-
return xs
14-
end
15-
16-
function ascend_dg_all(xs, dg, level, maxlevel)
17-
r = Int[]
18-
while true
19-
if level <= 0
20-
append!(r, xs)
21-
end
22-
maxlevel <= 0 && break
23-
xs = Int[dg[x] for x in xs if dg[x] !== nothing]
24-
level -= 1
25-
maxlevel -= 1
26-
end
27-
return r
28-
end
29-
30-
function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, varlevel,
31-
inv_varlevel, inv_eqlevel)
32-
@unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure
33-
34-
# var_eq_matching is a maximal matching on the top-differentiated variables.
35-
# Find Strongly connected components. Note that after pantelides, we expect
36-
# a balanced system, so a maximal matching should be possible.
37-
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, maximal_top_matching)
38-
var_eq_matching = Matching{Union{Unassigned, SelectedState}}(ndsts(graph))
39-
for vars in var_sccs
40-
# TODO: We should have a way to not have the scc code look at unassigned vars.
41-
if length(vars) == 1 && maximal_top_matching[vars[1]] === unassigned
42-
continue
43-
end
44-
45-
# Now proceed level by level from lowest to highest and tear the graph.
46-
eqs = [maximal_top_matching[var]
47-
for var in vars if maximal_top_matching[var] !== unassigned]
48-
isempty(eqs) && continue
49-
maxeqlevel = maximum(map(x -> inv_eqlevel[x], eqs))
50-
maxvarlevel = level = maximum(map(x -> inv_varlevel[x], vars))
51-
old_level_vars = ()
52-
ict = IncrementalCycleTracker(
53-
DiCMOBiGraph{true}(graph,
54-
complete(Matching(ndsts(graph)), nsrcs(graph))),
55-
dir = :in)
56-
57-
while level >= 0
58-
to_tear_eqs_toplevel = filter(eq -> inv_eqlevel[eq] >= level, eqs)
59-
to_tear_eqs = ascend_dg(to_tear_eqs_toplevel, invview(eq_to_diff), level)
60-
61-
to_tear_vars_toplevel = filter(var -> inv_varlevel[var] >= level, vars)
62-
to_tear_vars = ascend_dg(to_tear_vars_toplevel, invview(var_to_diff), level)
63-
64-
assigned_eqs = Int[]
65-
66-
if old_level_vars !== ()
67-
# Inherit constraints from previous level.
68-
# TODO: Is this actually a good idea or do we want full freedom
69-
# to tear differently on each level? Does it make a difference
70-
# whether we're using heuristic or optimal tearing?
71-
removed_eqs = Int[]
72-
removed_vars = Int[]
73-
for var in old_level_vars
74-
old_assign = var_eq_matching[var]
75-
if isa(old_assign, SelectedState)
76-
push!(removed_vars, var)
77-
continue
78-
elseif !isa(old_assign, Int) ||
79-
ict.graph.matching[var_to_diff[var]] !== unassigned
80-
continue
81-
end
82-
# Make sure the ict knows about this edge, so it doesn't accidentally introduce
83-
# a cycle.
84-
assgned_eq = eq_to_diff[old_assign]
85-
ok = try_assign_eq!(ict, var_to_diff[var], assgned_eq)
86-
@assert ok
87-
var_eq_matching[var_to_diff[var]] = assgned_eq
88-
push!(removed_eqs, eq_to_diff[ict.graph.matching[var]])
89-
push!(removed_vars, var_to_diff[var])
90-
push!(removed_vars, var)
91-
end
92-
to_tear_eqs = setdiff(to_tear_eqs, removed_eqs)
93-
to_tear_vars = setdiff(to_tear_vars, removed_vars)
94-
end
95-
tearEquations!(ict, solvable_graph.fadjlist, to_tear_eqs, BitSet(to_tear_vars),
96-
nothing)
97-
98-
for var in to_tear_vars
99-
@assert var_eq_matching[var] === unassigned
100-
assgned_eq = ict.graph.matching[var]
101-
var_eq_matching[var] = assgned_eq
102-
isa(assgned_eq, Int) && push!(assigned_eqs, assgned_eq)
103-
end
104-
105-
if level != 0
106-
remaining_vars = collect(v
107-
for v in to_tear_vars
108-
if var_eq_matching[v] === unassigned)
109-
if !isempty(remaining_vars)
110-
remaining_eqs = setdiff(to_tear_eqs, assigned_eqs)
111-
nlsolve_matching = maximal_matching(graph,
112-
Base.Fix2(in, remaining_eqs),
113-
Base.Fix2(in, remaining_vars))
114-
for var in remaining_vars
115-
if nlsolve_matching[var] === unassigned &&
116-
var_eq_matching[var] === unassigned
117-
var_eq_matching[var] = SelectedState()
118-
end
119-
end
120-
end
121-
end
122-
123-
old_level_vars = to_tear_vars
124-
level -= 1
125-
end
126-
end
127-
return complete(var_eq_matching, nsrcs(graph))
128-
end
129-
1301
struct SelectedState end
131-
function partial_state_selection_graph!(structure::SystemStructure, var_eq_matching)
132-
@unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure
133-
eq_to_diff = complete(eq_to_diff)
134-
135-
inv_eqlevel = map(1:nsrcs(graph)) do eq
136-
level = 0
137-
while invview(eq_to_diff)[eq] !== nothing
138-
eq = invview(eq_to_diff)[eq]
139-
level += 1
140-
end
141-
level
142-
end
143-
144-
varlevel = map(1:ndsts(graph)) do var
145-
graph_level = level = 0
146-
while var_to_diff[var] !== nothing
147-
var = var_to_diff[var]
148-
level += 1
149-
if !isempty(𝑑neighbors(graph, var))
150-
graph_level = level
151-
end
152-
end
153-
graph_level
154-
end
155-
156-
inv_varlevel = map(1:ndsts(graph)) do var
157-
level = 0
158-
while invview(var_to_diff)[var] !== nothing
159-
var = invview(var_to_diff)[var]
160-
level += 1
161-
end
162-
level
163-
end
164-
165-
var_eq_matching = pss_graph_modia!(structure,
166-
complete(var_eq_matching), varlevel, inv_varlevel,
167-
inv_eqlevel)
168-
169-
var_eq_matching
170-
end
1712

1723
function dummy_derivative_graph!(state::TransformationState, jac = nothing;
1734
state_priority = nothing, log = Val(false), kwargs...)

test/state_selection.jl

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,6 @@ let dd = dummy_derivative(sys)
2020
@test length(unknowns(dd)) == length(equations(dd)) < 9
2121
end
2222

23-
@test_skip let pss = partial_state_selection(sys)
24-
@test length(equations(pss)) == 1
25-
@test length(unknowns(pss)) == 2
26-
end
27-
28-
@parameters σ ρ β
29-
@variables x(t) y(t) z(t) a(t) u(t) F(t)
30-
31-
eqs = [D(x) ~ σ * (y - x)
32-
D(y) ~ x *- z) - y + β
33-
0 ~ z - x + y
34-
0 ~ a + z
35-
u ~ z + a]
36-
37-
lorenz1 = System(eqs, t, name = :lorenz1)
38-
let al1 = alias_elimination(lorenz1)
39-
let lss = partial_state_selection(al1)
40-
@test length(equations(lss)) == 2
41-
end
42-
end
43-
4423
# 1516
4524
let
4625
@connector function Fluid_port(; name, p = 101325.0, m = 0.0, T = 293.15)

test/structural_transformation/index_reduction.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,13 @@ eqs2 = [D(D(x)) ~ T * x,
3131
0 ~ x^2 + y^2 - L^2]
3232
pendulum2 = System(eqs2, t, [x, y, T], [L, g], name = :pendulum)
3333

34-
@test_skip begin
35-
let pss_pendulum2 = partial_state_selection(pendulum2)
36-
length(equations(pss_pendulum2)) <= 6
37-
end
38-
end
39-
4034
eqs = [D(x) ~ w,
4135
D(y) ~ z,
4236
D(w) ~ T * x,
4337
D(z) ~ T * y - g,
4438
0 ~ x^2 + y^2 - L^2]
4539
pendulum = System(eqs, t, [x, y, w, z, T], [L, g], name = :pendulum)
4640

47-
let pss_pendulum = partial_state_selection(pendulum)
48-
# This currently selects `T` rather than `x` at top level. Needs tearing priorities to fix.
49-
@test_broken length(equations(pss_pendulum)) == 3
50-
end
51-
5241
let sys = mtkcompile(pendulum2)
5342
@test length(equations(sys)) == 5
5443
@test length(unknowns(sys)) == 5

0 commit comments

Comments
 (0)