Skip to content

Commit 9271b4d

Browse files
authored
Merge pull request #1975 from Keno/kf/aliasadv
Try to avoid violating AliasGraph invariants
2 parents 779d547 + fd5c157 commit 9271b4d

File tree

4 files changed

+53
-38
lines changed

4 files changed

+53
-38
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ function parameters end
122122
# this has to be included early to deal with depency issues
123123
include("structural_transformation/bareiss.jl")
124124
function complete end
125+
function var_derivative! end
126+
function var_derivative_graph! end
125127
include("bipartite_graph.jl")
126128
using .BipartiteGraphs
127129

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2727

2828
using ModelingToolkit.BipartiteGraphs
2929
import .BipartiteGraphs: invview, complete
30+
import ModelingToolkit: var_derivative!, var_derivative_graph!
3031
using Graphs
3132
using ModelingToolkit.SystemStructures
3233
using ModelingToolkit.SystemStructures: algeqs, EquationsView

src/systems/alias_elimination.jl

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,15 @@ function alias_eliminate_graph!(state::TransformationState; kwargs...)
1111
end
1212

1313
@unpack graph, var_to_diff, solvable_graph = state.structure
14-
ag, mm, complete_ag, complete_mm, updated_diff_vars = alias_eliminate_graph!(complete(graph),
15-
complete(var_to_diff),
16-
mm)
14+
ag, mm, complete_ag, complete_mm = alias_eliminate_graph!(state, mm)
1715
if solvable_graph !== nothing
1816
for (ei, e) in enumerate(mm.nzrows)
1917
set_neighbors!(solvable_graph, e, mm.row_cols[ei])
2018
end
2119
update_graph_neighbors!(solvable_graph, ag)
2220
end
2321

24-
return ag, mm, complete_ag, complete_mm, updated_diff_vars
22+
return ag, mm, complete_ag, complete_mm
2523
end
2624

2725
# For debug purposes
@@ -51,23 +49,12 @@ function alias_elimination!(state::TearingState; kwargs...)
5149
sys = state.sys
5250
complete!(state.structure)
5351
graph_orig = copy(state.structure.graph)
54-
ag, mm, complete_ag, complete_mm, updated_diff_vars = alias_eliminate_graph!(state;
55-
kwargs...)
52+
ag, mm, complete_ag, complete_mm = alias_eliminate_graph!(state; kwargs...)
5653
isempty(ag) && return sys, ag
5754

5855
fullvars = state.fullvars
5956
@unpack var_to_diff, graph, solvable_graph = state.structure
6057

61-
if !isempty(updated_diff_vars)
62-
has_iv(sys) ||
63-
error(InvalidSystemException("The system has no independent variable!"))
64-
D = Differential(get_iv(sys))
65-
for v in updated_diff_vars
66-
dv = var_to_diff[v]
67-
fullvars[dv] = D(fullvars[v])
68-
end
69-
end
70-
7158
subs = Dict()
7259
obs = Equation[]
7360
# If we encounter y = -D(x), then we need to expand the derivative when
@@ -328,6 +315,9 @@ function Base.setindex!(ag::AliasGraph, ::Nothing, i::Integer)
328315
end
329316
function Base.setindex!(ag::AliasGraph, v::Integer, i::Integer)
330317
@assert v == 0
318+
if i > length(ag.aliasto)
319+
resize!(ag.aliasto, i)
320+
end
331321
if ag.aliasto[i] === nothing
332322
push!(ag.eliminated, i)
333323
end
@@ -343,6 +333,9 @@ function Base.setindex!(ag::AliasGraph, p::Union{Pair{Int, Int}, Tuple{Int, Int}
343333
return p
344334
end
345335
@assert v != 0 && c in (-1, 1)
336+
if i > length(ag.aliasto)
337+
resize!(ag.aliasto, i)
338+
end
346339
if ag.aliasto[i] === nothing
347340
push!(ag.eliminated, i)
348341
end
@@ -379,6 +372,7 @@ function Graphs.add_edge!(g::WeightedGraph, u, v, w)
379372
r && (g.dict[canonicalize(u, v)] = w)
380373
r
381374
end
375+
Graphs.add_vertex!(g::WeightedGraph) = add_vertex!(g.graph)
382376
Graphs.has_edge(g::WeightedGraph, u, v) = has_edge(g.graph, u, v)
383377
Graphs.ne(g::WeightedGraph) = ne(g.graph)
384378
Graphs.nv(g::WeightedGraph) = nv(g.graph)
@@ -656,7 +650,20 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig)
656650
return mm, echelon_mm
657651
end
658652

659-
function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
653+
function var_derivative_here!(state, processed, g, eqg, dls, diff_var)
654+
newvar = var_derivative!(state, diff_var)
655+
@assert newvar == length(processed) + 1
656+
push!(processed, true)
657+
add_vertex!(g)
658+
add_vertex!(eqg)
659+
add_edge!(g, diff_var, newvar)
660+
add_edge!(g, newvar, diff_var)
661+
push!(dls.dists, typemax(Int))
662+
return newvar
663+
end
664+
665+
function alias_eliminate_graph!(state::TransformationState, mm_orig::SparseMatrixCLIL)
666+
@unpack graph, var_to_diff = state.structure
660667
# Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
661668
# subsystem of the system we're interested in.
662669
#
@@ -689,11 +696,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
689696
# with a tie breaking strategy, the root variable (in this case `z`) is
690697
# always uniquely determined. Thus, the result is well-defined.
691698
dag = AliasGraph(nvars) # alias graph for differentiated variables
692-
updated_diff_vars = Int[]
693699
diff_to_var = invview(var_to_diff)
694700
processed = falses(nvars)
695701
g, eqg, zero_vars = equality_diff_graph(ag, var_to_diff)
696702
dls = DiffLevelState(g, var_to_diff)
703+
original_nvars = length(var_to_diff)
704+
697705
is_diff_edge = let var_to_diff = var_to_diff
698706
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
699707
end
@@ -709,10 +717,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
709717
for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop
710718
reach₌ = Pair{Int, Int}[]
711719
# `r` is aliased to its equality aliases
712-
r === nothing || for n in neighbors(eqg, r)
713-
(n == r || is_diff_edge(r, n)) && continue
714-
c = get_weight(eqg, r, n)
715-
push!(reach₌, c => n)
720+
if r !== nothing
721+
for n in neighbors(eqg, r)
722+
(n == r || is_diff_edge(r, n)) && continue
723+
c = get_weight(eqg, r, n)
724+
push!(reach₌, c => n)
725+
end
716726
end
717727
# `r` is aliased to its previous differentiation level's aliases'
718728
# derivative
@@ -733,19 +743,15 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
733743
end
734744
if r === nothing
735745
isempty(reach₌) && break
736-
idx = findfirst(x -> x[1] == 1, reach₌)
737-
if idx === nothing
738-
c, dr = reach₌[1]
739-
@assert c == -1
740-
dag[dr] = (c, dr)
741-
else
742-
c, dr = reach₌[idx]
743-
@assert c == 1
746+
let stem_set = stem_set
747+
any(x -> x[2] in stem_set, reach₌) && break
744748
end
745-
dr in stem_set && break
746-
var_to_diff[prev_r] = dr
747-
push!(updated_diff_vars, prev_r)
748-
prev_r = dr
749+
# See example in the box above where D(D(D(z))) doesn't appear
750+
# in the original system and needs to added, so we can alias to it.
751+
# We do that here.
752+
@assert prev_r !== -1
753+
prev_r = var_derivative_here!(state, processed, g, eqg, dls, prev_r)
754+
r = nothing
749755
else
750756
prev_r = r
751757
r = var_to_diff[r]
@@ -806,15 +812,18 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
806812
for i in 1:(length(stem) - 1)
807813
r = stem[i]
808814
for dr in @view stem[(i + 1):end]
815+
# We cannot reduce newly introduced variables like `D(D(D(z)))`
816+
# in the example box above.
817+
dr > original_nvars && continue
809818
if has_edge(eqg, r, dr)
810819
c = get_weight(eqg, r, dr)
811820
dag[dr] = c => r
812821
end
813822
end
814823
end
815824
# If a non-differentiated variable equals to 0, then we can eliminate
816-
# the whole differentiation chain. Otherwise, we can have to keep the
817-
# lowest differentiate variable in the differentiation chain.
825+
# the whole differentiation chain. Otherwise, we will still have to keep
826+
# the lowest differentiated variable in the differentiation chain.
818827
# E.g.
819828
# ```
820829
# D(x) ~ 0
@@ -854,6 +863,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
854863
end
855864
# reducible after v
856865
while (v = var_to_diff[v]) !== nothing
866+
complete_ag[v] = 0
857867
dag[v] = 0
858868
end
859869
end
@@ -900,6 +910,8 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
900910
merged_ag[v] = c => a
901911
end
902912
ag = merged_ag
913+
@set! echelon_mm.ncols = length(var_to_diff)
914+
@set! mm_orig.ncols = length(var_to_diff)
903915
mm = reduce!(copy(echelon_mm), mm_orig, ag, size(echelon_mm, 1))
904916

905917
# Step 5: Reflect our update decisions back into the graph, and make sure
@@ -940,7 +952,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
940952
end
941953

942954
complete_mm = reduce!(copy(echelon_mm), mm_orig, complete_ag, size(echelon_mm, 1))
943-
return ag, mm, complete_ag, complete_mm, updated_diff_vars
955+
return ag, mm, complete_ag, complete_mm
944956
end
945957

946958
function update_graph_neighbors!(graph, ag)

test/reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ eqs = [a ~ D(w)
256256
@named sys = ODESystem(eqs, t, vars, [])
257257
ss = alias_elimination(sys)
258258
@test equations(ss) == [0 ~ D(D(phi)) - a, 0 ~ sin(t) - D(phi)]
259-
@test observed(ss) == [w ~ D(phi)]
259+
@test observed(ss) == [w ~ D(phi), D(w) ~ D(D(phi))]
260260

261261
@variables t x(t) y(t)
262262
D = Differential(t)

0 commit comments

Comments
 (0)