Skip to content

Commit ba12d6b

Browse files
committed
Allow filtering which variables are considered differentiated
1 parent 2d03c65 commit ba12d6b

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

src/debug.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ end
2929
function Base.getindex(bgpm::SystemStructurePrintMatrix, i::Integer, j::Integer)
3030
checkbounds(bgpm, i, j)
3131
if i <= 1
32-
return (Label.(("#", "∂ₜ", " ", " eq", "", "#", "∂ₜ", " ", " v")))[j]
32+
return (Label.(("# eq", "∂ₜ", " ", " ", "", "# v", "∂ₜ", " ", " ")))[j]
3333
elseif j == 5
3434
colors = Base.text_colors
3535
return Label("|", :light_black)

src/pantelides.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@ case, there is one complicating condition:
1515
This function takes care of these complications are returns a boolean array
1616
for every variable, indicating whether it is considered "highest-differentiated".
1717
"""
18-
function computed_highest_diff_variables(structure)
18+
function computed_highest_diff_variables(structure, diffvars::Union{BitVector, BitSet, Nothing}=nothing)
1919
@unpack graph, var_to_diff = structure
2020

2121
nvars = length(var_to_diff)
2222
varwhitelist = falses(nvars)
2323
for var in 1:nvars
24+
_canchoose(diffvars, var) || continue
2425
if var_to_diff[var] === nothing && !varwhitelist[var]
2526
# This variable is structurally highest-differentiated, but may not actually appear in the
2627
# system (complication 1 above). Ascend the differentiation graph to find the highest
@@ -49,6 +50,9 @@ function computed_highest_diff_variables(structure)
4950

5051
return varwhitelist
5152
end
53+
_canchoose(diffvars::BitSet, var::Integer) = var in diffvars
54+
_canchoose(diffvars::BitVector, var::Integer) = diffvars[var]
55+
_canchoose(diffvars::Nothing, var::Integer) = true
5256

5357
"""
5458
pantelides!(state::TransformationState; kwargs...)

src/partial_state_selection.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,13 @@ function ascend_dg_all(xs, dg, level, maxlevel)
2828
return r
2929
end
3030

31-
function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, varlevel,
32-
inv_varlevel, inv_eqlevel)
31+
struct DiffData
32+
varlevel::Vector{Int}
33+
inv_varlevel::Vector{Int}
34+
inv_eqlevel::Vector{Int}
35+
end
36+
37+
function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, diff_data::Union{Nothing, DiffData}=nothing)
3338
@unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure
3439

3540
# var_eq_matching is a maximal matching on the top-differentiated variables.
@@ -47,19 +52,19 @@ function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, varl
4752
eqs = [maximal_top_matching[var]
4853
for var in vars if maximal_top_matching[var] !== unassigned]
4954
isempty(eqs) && continue
50-
maxeqlevel = maximum(map(x -> inv_eqlevel[x], eqs))
51-
maxvarlevel = level = maximum(map(x -> inv_varlevel[x], vars))
55+
maxeqlevel = diff_data === nothing ? 0 : maximum(map(x -> diff_data.inv_eqlevel[x], eqs))
56+
maxvarlevel = level = diff_data === nothing ? 0 : maximum(map(x -> diff_data.inv_varlevel[x], vars))
5257
old_level_vars = ()
5358
ict = IncrementalCycleTracker(
5459
DiCMOBiGraph{true}(graph,
5560
complete(Matching(ndsts(graph)), nsrcs(graph))),
5661
dir = :in)
5762

5863
while level >= 0
59-
to_tear_eqs_toplevel = filter(eq -> inv_eqlevel[eq] >= level, eqs)
64+
to_tear_eqs_toplevel = level == 0 ? eqs : filter(eq -> diff_data.inv_eqlevel[eq] >= level, eqs)
6065
to_tear_eqs = ascend_dg(to_tear_eqs_toplevel, invview(eq_to_diff), level)
6166

62-
to_tear_vars_toplevel = filter(var -> inv_varlevel[var] >= level, vars)
67+
to_tear_vars_toplevel = level == 0 ? vars : filter(var -> diff_data.inv_varlevel[var] >= level, vars)
6368
to_tear_vars = ascend_dg(to_tear_vars_toplevel, invview(var_to_diff), level)
6469

6570
assigned_eqs = Int[]
@@ -163,8 +168,8 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match
163168
end
164169

165170
var_eq_matching = pss_graph_modia!(structure,
166-
complete(var_eq_matching), varlevel, inv_varlevel,
167-
inv_eqlevel)
171+
complete(var_eq_matching), DiffData(varlevel, inv_varlevel,
172+
inv_eqlevel))
168173

169174
var_eq_matching
170175
end

0 commit comments

Comments
 (0)