Skip to content

Commit d950da5

Browse files
committed
WIP: Try to avoid violating AlisGraph invariants
In f0ec298, AliasGraph gained cases where x is aliased to `-x`. This happens when there a side branch of the differentiation tree that was deeper than the stem. Currently alias elimination attempts to mutate-in-place the var-to-diff relationships in order to move the side branch onto the main stem. This is problematic for two reasons: 1. It assumes that TransformationState is mutable and that var_to_diff relationships may be arbitrarily updated. This is not necessarily the case. 2. It abuses the AliasGraph data structure. Ordinarily an x -> -x alias would force these to be equal, i.e. imply `x->0`. However, here it is used as an in-place marker of this branch transplant. This attempts to fix both of these issues by allowing alias_elimination to introduce extra derivative variables on the main stem (similar to pantelides introducing extra derivative variables).
1 parent 48ac4f3 commit d950da5

File tree

3 files changed

+41
-36
lines changed

3 files changed

+41
-36
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ function parameters end
122122
# this has to be included early to deal with depency issues
123123
include("structural_transformation/bareiss.jl")
124124
function complete end
125+
function var_derivative! end
126+
function var_derivative_graph! end
125127
include("bipartite_graph.jl")
126128
using .BipartiteGraphs
127129

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2727

2828
using ModelingToolkit.BipartiteGraphs
2929
import .BipartiteGraphs: invview, complete
30+
import ModelingToolkit: var_derivative!, var_derivative_graph!
3031
using Graphs
3132
using ModelingToolkit.SystemStructures
3233
using ModelingToolkit.SystemStructures: algeqs, EquationsView

src/systems/alias_elimination.jl

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,15 @@ function alias_eliminate_graph!(state::TransformationState; kwargs...)
1111
end
1212

1313
@unpack graph, var_to_diff, solvable_graph = state.structure
14-
ag, mm, complete_ag, complete_mm, updated_diff_vars = alias_eliminate_graph!(complete(graph),
15-
complete(var_to_diff),
16-
mm)
14+
ag, mm, complete_ag, complete_mm = alias_eliminate_graph!(state, mm)
1715
if solvable_graph !== nothing
1816
for (ei, e) in enumerate(mm.nzrows)
1917
set_neighbors!(solvable_graph, e, mm.row_cols[ei])
2018
end
2119
update_graph_neighbors!(solvable_graph, ag)
2220
end
2321

24-
return ag, mm, complete_ag, complete_mm, updated_diff_vars
22+
return ag, mm, complete_ag, complete_mm
2523
end
2624

2725
# For debug purposes
@@ -51,23 +49,12 @@ function alias_elimination!(state::TearingState; kwargs...)
5149
sys = state.sys
5250
complete!(state.structure)
5351
graph_orig = copy(state.structure.graph)
54-
ag, mm, complete_ag, complete_mm, updated_diff_vars = alias_eliminate_graph!(state;
55-
kwargs...)
52+
ag, mm, complete_ag, complete_mm = alias_eliminate_graph!(state; kwargs...)
5653
isempty(ag) && return sys, ag
5754

5855
fullvars = state.fullvars
5956
@unpack var_to_diff, graph, solvable_graph = state.structure
6057

61-
if !isempty(updated_diff_vars)
62-
has_iv(sys) ||
63-
error(InvalidSystemException("The system has no independent variable!"))
64-
D = Differential(get_iv(sys))
65-
for v in updated_diff_vars
66-
dv = var_to_diff[v]
67-
fullvars[dv] = D(fullvars[v])
68-
end
69-
end
70-
7158
subs = Dict()
7259
obs = Equation[]
7360
# If we encounter y = -D(x), then we need to expand the derivative when
@@ -328,6 +315,9 @@ function Base.setindex!(ag::AliasGraph, ::Nothing, i::Integer)
328315
end
329316
function Base.setindex!(ag::AliasGraph, v::Integer, i::Integer)
330317
@assert v == 0
318+
if i > length(ag.aliasto)
319+
resize!(ag.aliasto, i)
320+
end
331321
if ag.aliasto[i] === nothing
332322
push!(ag.eliminated, i)
333323
end
@@ -343,6 +333,9 @@ function Base.setindex!(ag::AliasGraph, p::Union{Pair{Int, Int}, Tuple{Int, Int}
343333
return p
344334
end
345335
@assert v != 0 && c in (-1, 1)
336+
if i > length(ag.aliasto)
337+
resize!(ag.aliasto, i)
338+
end
346339
if ag.aliasto[i] === nothing
347340
push!(ag.eliminated, i)
348341
end
@@ -379,6 +372,7 @@ function Graphs.add_edge!(g::WeightedGraph, u, v, w)
379372
r && (g.dict[canonicalize(u, v)] = w)
380373
r
381374
end
375+
Graphs.add_vertex!(g::WeightedGraph) = add_vertex!(g.graph)
382376
Graphs.has_edge(g::WeightedGraph, u, v) = has_edge(g.graph, u, v)
383377
Graphs.ne(g::WeightedGraph) = ne(g.graph)
384378
Graphs.nv(g::WeightedGraph) = nv(g.graph)
@@ -656,7 +650,8 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig)
656650
return mm, echelon_mm
657651
end
658652

659-
function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
653+
function alias_eliminate_graph!(state::TransformationState, mm_orig::SparseMatrixCLIL)
654+
@unpack graph, var_to_diff = state.structure
660655
# Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
661656
# subsystem of the system we're interested in.
662657
#
@@ -689,11 +684,23 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
689684
# with a tie breaking strategy, the root variable (in this case `z`) is
690685
# always uniquely determined. Thus, the result is well-defined.
691686
dag = AliasGraph(nvars) # alias graph for differentiated variables
692-
updated_diff_vars = Int[]
693687
diff_to_var = invview(var_to_diff)
694688
processed = falses(nvars)
695689
g, eqg, zero_vars = equality_diff_graph(ag, var_to_diff)
696690
dls = DiffLevelState(g, var_to_diff)
691+
692+
function var_drivative_here!(diff_var)
693+
newvar = var_derivative!(state, diff_var)
694+
@assert newvar == length(processed)+1
695+
push!(processed, true)
696+
add_vertex!(g)
697+
add_vertex!(eqg)
698+
add_edge!(g, diff_var, newvar)
699+
add_edge!(g, newvar, diff_var)
700+
push!(dls.dists, typemax(Int))
701+
return newvar
702+
end
703+
697704
is_diff_edge = let var_to_diff = var_to_diff
698705
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
699706
end
@@ -709,10 +716,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
709716
for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop
710717
reach₌ = Pair{Int, Int}[]
711718
# `r` is aliased to its equality aliases
712-
r === nothing || for n in neighbors(eqg, r)
713-
(n == r || is_diff_edge(r, n)) && continue
714-
c = get_weight(eqg, r, n)
715-
push!(reach₌, c => n)
719+
if r !== nothing
720+
for n in neighbors(eqg, r)
721+
(n == r || is_diff_edge(r, n)) && continue
722+
c = get_weight(eqg, r, n)
723+
push!(reach₌, c => n)
724+
end
716725
end
717726
# `r` is aliased to its previous differentiation level's aliases'
718727
# derivative
@@ -733,19 +742,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
733742
end
734743
if r === nothing
735744
isempty(reach₌) && break
736-
idx = findfirst(x -> x[1] == 1, reach₌)
737-
if idx === nothing
738-
c, dr = reach₌[1]
739-
@assert c == -1
740-
dag[dr] = (c, dr)
741-
else
742-
c, dr = reach₌[idx]
743-
@assert c == 1
744-
end
745-
dr in stem_set && break
746-
var_to_diff[prev_r] = dr
747-
push!(updated_diff_vars, prev_r)
748-
prev_r = dr
745+
# See example in the box above where D(D(D(z))) doesn't appear
746+
# in the original system and needs to added, so we can alias to it.
747+
# We do that here.
748+
@assert prev_r !== -1
749+
prev_r = var_drivative_here!(prev_r)
750+
r = nothing
749751
else
750752
prev_r = r
751753
r = var_to_diff[r]
@@ -940,7 +942,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
940942
end
941943

942944
complete_mm = reduce!(copy(echelon_mm), mm_orig, complete_ag, size(echelon_mm, 1))
943-
return ag, mm, complete_ag, complete_mm, updated_diff_vars
945+
return ag, mm, complete_ag, complete_mm
944946
end
945947

946948
function update_graph_neighbors!(graph, ag)

0 commit comments

Comments
 (0)