Skip to content

Commit 83d6cee

Browse files
authored
Merge pull request #1843 from SciML/myb/complete_ag
Compute `complete_ag` that contains irreducible alias information
2 parents aed17f4 + e943078 commit 83d6cee

File tree

6 files changed

+65
-17
lines changed

6 files changed

+65
-17
lines changed

src/bipartite_graph.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ end
6363
function Base.setindex!(m::Matching{U}, v::Union{Integer, U}, i::Integer) where {U}
6464
if m.inv_match !== nothing
6565
oldv = m.match[i]
66+
# TODO: maybe default Matching to always have an `inv_match`?
67+
68+
# To maintain the invariant that `m.inv_match[m.match[i]] == i`, we need
69+
# to unassign the matching at `m.inv_match[v]` if it exists.
70+
if v isa Int && (iv = m.inv_match[v]) isa Int
71+
m.match[iv] = unassigned
72+
end
6673
if isa(oldv, Int)
6774
@assert m.inv_match[oldv] == i
6875
m.inv_match[oldv] = unassigned

src/structural_transformation/pantelides.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,27 @@ function pantelides!(state::TransformationState, ag::Union{AliasGraph, Nothing}
8383
ecolor = falses(neqs)
8484
var_eq_matching = Matching(nvars)
8585
neqs′ = neqs
86-
nnonemptyeqs = count(eq -> !isempty(𝑠neighbors(graph, eq)), 1:neqs′)
86+
nnonemptyeqs = count(eq -> !isempty(𝑠neighbors(graph, eq)) && eq_to_diff[eq] === nothing,
87+
1:neqs′)
8788

8889
# Allow matching for the highest differentiated variable that
8990
# currently appears in an equation (or implicit equation in a side ag)
9091
varwhitelist = falses(nvars)
9192
for var in 1:nvars
92-
if var_to_diff[var] === nothing
93+
if var_to_diff[var] === nothing && !varwhitelist[var]
9394
while isempty(𝑑neighbors(graph, var)) && (ag === nothing || !haskey(ag, var))
9495
var′ = invview(var_to_diff)[var]
9596
var′ === nothing && break
9697
var = var′
9798
end
98-
if !isempty(𝑑neighbors(graph, var)) || (ag !== nothing && haskey(ag, var))
99-
varwhitelist[var] = true
99+
if !isempty(𝑑neighbors(graph, var))
100+
if ag !== nothing && haskey(ag, var)
101+
# TODO: remove lower diff vars from whitelist
102+
c, a = ag[var]
103+
iszero(c) || (varwhitelist[a] = true)
104+
else
105+
varwhitelist[var] = true
106+
end
100107
end
101108
end
102109
end
@@ -107,6 +114,7 @@ function pantelides!(state::TransformationState, ag::Union{AliasGraph, Nothing}
107114

108115
for k in 1:neqs′
109116
eq′ = k
117+
eq_to_diff[eq′] === nothing || continue
110118
isempty(𝑠neighbors(graph, eq′)) && continue
111119
pathfound = false
112120
# In practice, `maxiters=8000` should never be reached, otherwise, the

src/structural_transformation/partial_state_selection.jl

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ function pss_graph_modia!(structure::SystemStructure, var_eq_matching, varlevel,
9898
end
9999
end
100100
for var in 1:ndsts(graph)
101-
if varlevel[var] !== 0 && var_eq_matching[var] === unassigned
101+
dv = var_to_diff[var]
102+
# If `var` is not algebraic (not differentiated nor a dummy derivative),
103+
# then it's a SelectedState
104+
if !(dv === nothing || (varlevel[dv] !== 0 && var_eq_matching[dv] === unassigned))
102105
var_eq_matching[var] = SelectedState()
103106
end
104107
end
@@ -154,11 +157,13 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match
154157
var_eq_matching
155158
end
156159

157-
function dummy_derivative_graph!(state::TransformationState, jac = nothing; kwargs...)
160+
function dummy_derivative_graph!(state::TransformationState, jac = nothing,
161+
(ag, diff_va) = (nothing, nothing);
162+
kwargs...)
158163
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
159-
var_eq_matching = complete(pantelides!(state))
164+
var_eq_matching = complete(pantelides!(state, ag))
160165
complete!(state.structure)
161-
dummy_derivative_graph!(state.structure, var_eq_matching, jac)
166+
dummy_derivative_graph!(state.structure, var_eq_matching, jac, (ag, diff_va))
162167
end
163168

164169
function compute_diff_level(diff_to_x)
@@ -178,7 +183,8 @@ function compute_diff_level(diff_to_x)
178183
return xlevel, maxlevel
179184
end
180185

181-
function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, jac)
186+
function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, jac,
187+
(ag, diff_va))
182188
@unpack eq_to_diff, var_to_diff, graph = structure
183189
diff_to_eq = invview(eq_to_diff)
184190
diff_to_var = invview(var_to_diff)
@@ -242,6 +248,18 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
242248
vars = [diff_to_var[var] for var in vars if diff_to_var[var] !== nothing]
243249
end
244250
end
251+
if diff_va !== nothing
252+
n_dummys = length(dummy_derivatives)
253+
needed = count(x -> x isa Int, diff_to_eq) - n_dummys
254+
n = 0
255+
for v in diff_va
256+
c, a = ag[v]
257+
n += 1
258+
push!(dummy_derivatives, iszero(c) ? v : a)
259+
needed == n && break
260+
continue
261+
end
262+
end
245263

246264
dummy_derivatives_set = BitSet(dummy_derivatives)
247265
# We can eliminate variables that are not a selected state (differential
@@ -251,6 +269,9 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
251269
dummy_derivatives_set = dummy_derivatives_set
252270

253271
v -> begin
272+
if ag !== nothing
273+
haskey(ag, v) && return false
274+
end
254275
dv = var_to_diff[v]
255276
dv === nothing || dv in dummy_derivatives_set
256277
end
@@ -266,7 +287,8 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
266287
Union{Unassigned, SelectedState};
267288
varfilter = can_eliminate)
268289
for v in eachindex(var_eq_matching)
269-
can_eliminate(v) && continue
290+
dv = var_to_diff[v]
291+
(dv === nothing || dv in dummy_derivatives_set) && continue
270292
var_eq_matching[v] = SelectedState()
271293
end
272294

src/systems/abstractsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ function namespace_assignment(eq::Assignment, sys)
405405
Assignment(_lhs, _rhs)
406406
end
407407

408-
function namespace_expr(O, sys, n = nameof(sys)) where {T}
408+
function namespace_expr(O, sys, n = nameof(sys))
409409
ivs = independent_variables(sys)
410410
O = unwrap(O)
411411
if any(isequal(O), ivs)

src/systems/alias_elimination.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ const KEEP = typemin(Int)
66

77
function alias_eliminate_graph!(state::TransformationState)
88
mm = linear_subsys_adjmat(state)
9-
size(mm, 1) == 0 && return AliasGraph(ndsts(state.structure.graph)), mm, BitSet() # No linear subsystems
9+
if size(mm, 1) == 0
10+
ag = AliasGraph(ndsts(state.structure.graph))
11+
return ag, mm, ag, mm, BitSet() # No linear subsystems
12+
end
1013

1114
@unpack graph, var_to_diff = state.structure
1215

@@ -39,7 +42,7 @@ alias_elimination(sys) = alias_elimination!(TearingState(sys; quick_cancel = tru
3942
function alias_elimination!(state::TearingState)
4043
sys = state.sys
4144
complete!(state.structure)
42-
ag, mm, updated_diff_vars = alias_eliminate_graph!(state)
45+
ag, mm, complete_ag, complete_mm, updated_diff_vars = alias_eliminate_graph!(state)
4346
isempty(ag) && return sys
4447

4548
fullvars = state.fullvars
@@ -543,6 +546,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
543546
#
544547
nvars = ndsts(graph)
545548
ag = AliasGraph(nvars)
549+
complete_ag = AliasGraph(nvars)
546550
mm, echelon_mm = simple_aliases!(ag, graph, var_to_diff, mm_orig)
547551

548552
# Step 3: Handle differentiated variables
@@ -702,6 +706,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
702706
while (iv = diff_to_var[v]) in zero_vars_set
703707
v = iv
704708
end
709+
complete_ag[v] = 0
705710
if diff_to_var[v] === nothing # `v` is reducible
706711
dag[v] = 0
707712
end
@@ -729,6 +734,13 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
729734
# Step 4: Merge dag and ag
730735
removed_aliases = BitSet()
731736
merged_ag = AliasGraph(nvars)
737+
for (v, (c, a)) in dag
738+
complete_ag[v] = c => a
739+
end
740+
for (v, (c, a)) in ag
741+
(processed[v] || (!iszero(a) && processed[a])) && continue
742+
complete_ag[v] = c => a
743+
end
732744
for (v, (c, a)) in dag
733745
# D(x) ~ D(y) cannot be removed if x and y are not aliases
734746
if v != a && !iszero(a)
@@ -789,7 +801,8 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
789801
update_graph_neighbors!(graph, ag)
790802
end
791803

792-
return ag, mm, updated_diff_vars
804+
complete_mm = reduce!(copy(echelon_mm), mm_orig, complete_ag, size(echelon_mm, 1))
805+
return ag, mm, complete_ag, complete_mm, updated_diff_vars
793806
end
794807

795808
function update_graph_neighbors!(graph, ag)

test/structural_transformation/index_reduction.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,7 @@ sol = solve(prob_auto, Rodas5());
119119
#plot(sol, vars=(D(x), y))
120120

121121
let pss_pendulum2 = partial_state_selection(pendulum2)
122-
# This currently selects `T` rather than `x` at top level. Needs tearing priorities to fix.
123-
@test length(equations(pss_pendulum2)) == 4
124-
@test length(equations(ModelingToolkit.ode_order_lowering(pss_pendulum2))) == 4
122+
@test length(equations(pss_pendulum2)) <= 6
125123
end
126124

127125
eqs = [D(x) ~ w,

0 commit comments

Comments
 (0)