@@ -11,17 +11,15 @@ function alias_eliminate_graph!(state::TransformationState; kwargs...)
11
11
end
12
12
13
13
@unpack graph, var_to_diff, solvable_graph = state. structure
14
- ag, mm, complete_ag, complete_mm, updated_diff_vars = alias_eliminate_graph! (complete (graph),
15
- complete (var_to_diff),
16
- mm)
14
+ ag, mm, complete_ag, complete_mm = alias_eliminate_graph! (state, mm)
17
15
if solvable_graph != = nothing
18
16
for (ei, e) in enumerate (mm. nzrows)
19
17
set_neighbors! (solvable_graph, e, mm. row_cols[ei])
20
18
end
21
19
update_graph_neighbors! (solvable_graph, ag)
22
20
end
23
21
24
- return ag, mm, complete_ag, complete_mm, updated_diff_vars
22
+ return ag, mm, complete_ag, complete_mm
25
23
end
26
24
27
25
# For debug purposes
@@ -51,23 +49,12 @@ function alias_elimination!(state::TearingState; kwargs...)
51
49
sys = state. sys
52
50
complete! (state. structure)
53
51
graph_orig = copy (state. structure. graph)
54
- ag, mm, complete_ag, complete_mm, updated_diff_vars = alias_eliminate_graph! (state;
55
- kwargs... )
52
+ ag, mm, complete_ag, complete_mm = alias_eliminate_graph! (state; kwargs... )
56
53
isempty (ag) && return sys, ag
57
54
58
55
fullvars = state. fullvars
59
56
@unpack var_to_diff, graph, solvable_graph = state. structure
60
57
61
- if ! isempty (updated_diff_vars)
62
- has_iv (sys) ||
63
- error (InvalidSystemException (" The system has no independent variable!" ))
64
- D = Differential (get_iv (sys))
65
- for v in updated_diff_vars
66
- dv = var_to_diff[v]
67
- fullvars[dv] = D (fullvars[v])
68
- end
69
- end
70
-
71
58
subs = Dict ()
72
59
obs = Equation[]
73
60
# If we encounter y = -D(x), then we need to expand the derivative when
@@ -328,6 +315,9 @@ function Base.setindex!(ag::AliasGraph, ::Nothing, i::Integer)
328
315
end
329
316
function Base. setindex! (ag:: AliasGraph , v:: Integer , i:: Integer )
330
317
@assert v == 0
318
+ if i > length (ag. aliasto)
319
+ resize! (ag. aliasto, i)
320
+ end
331
321
if ag. aliasto[i] === nothing
332
322
push! (ag. eliminated, i)
333
323
end
@@ -343,6 +333,9 @@ function Base.setindex!(ag::AliasGraph, p::Union{Pair{Int, Int}, Tuple{Int, Int}
343
333
return p
344
334
end
345
335
@assert v != 0 && c in (- 1 , 1 )
336
+ if i > length (ag. aliasto)
337
+ resize! (ag. aliasto, i)
338
+ end
346
339
if ag. aliasto[i] === nothing
347
340
push! (ag. eliminated, i)
348
341
end
@@ -379,6 +372,7 @@ function Graphs.add_edge!(g::WeightedGraph, u, v, w)
379
372
r && (g. dict[canonicalize (u, v)] = w)
380
373
r
381
374
end
375
+ Graphs. add_vertex! (g:: WeightedGraph ) = add_vertex! (g. graph)
382
376
Graphs. has_edge (g:: WeightedGraph , u, v) = has_edge (g. graph, u, v)
383
377
Graphs. ne (g:: WeightedGraph ) = ne (g. graph)
384
378
Graphs. nv (g:: WeightedGraph ) = nv (g. graph)
@@ -656,7 +650,8 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig)
656
650
return mm, echelon_mm
657
651
end
658
652
659
- function alias_eliminate_graph! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL )
653
+ function alias_eliminate_graph! (state:: TransformationState , mm_orig:: SparseMatrixCLIL )
654
+ @unpack graph, var_to_diff = state. structure
660
655
# Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
661
656
# subsystem of the system we're interested in.
662
657
#
@@ -689,11 +684,23 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
689
684
# with a tie breaking strategy, the root variable (in this case `z`) is
690
685
# always uniquely determined. Thus, the result is well-defined.
691
686
dag = AliasGraph (nvars) # alias graph for differentiated variables
692
- updated_diff_vars = Int[]
693
687
diff_to_var = invview (var_to_diff)
694
688
processed = falses (nvars)
695
689
g, eqg, zero_vars = equality_diff_graph (ag, var_to_diff)
696
690
dls = DiffLevelState (g, var_to_diff)
691
+
692
+ function var_drivative_here! (diff_var)
693
+ newvar = var_derivative! (state, diff_var)
694
+ @assert newvar == length (processed)+ 1
695
+ push! (processed, true )
696
+ add_vertex! (g)
697
+ add_vertex! (eqg)
698
+ add_edge! (g, diff_var, newvar)
699
+ add_edge! (g, newvar, diff_var)
700
+ push! (dls. dists, typemax (Int))
701
+ return newvar
702
+ end
703
+
697
704
is_diff_edge = let var_to_diff = var_to_diff
698
705
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
699
706
end
@@ -709,10 +716,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
709
716
for _ in 1 : 10_000 # just to make sure that we don't stuck in an infinite loop
710
717
reach₌ = Pair{Int, Int}[]
711
718
# `r` is aliased to its equality aliases
712
- r === nothing || for n in neighbors (eqg, r)
713
- (n == r || is_diff_edge (r, n)) && continue
714
- c = get_weight (eqg, r, n)
715
- push! (reach₌, c => n)
719
+ if r != = nothing
720
+ for n in neighbors (eqg, r)
721
+ (n == r || is_diff_edge (r, n)) && continue
722
+ c = get_weight (eqg, r, n)
723
+ push! (reach₌, c => n)
724
+ end
716
725
end
717
726
# `r` is aliased to its previous differentiation level's aliases'
718
727
# derivative
@@ -733,19 +742,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
733
742
end
734
743
if r === nothing
735
744
isempty (reach₌) && break
736
- idx = findfirst (x -> x[1 ] == 1 , reach₌)
737
- if idx === nothing
738
- c, dr = reach₌[1 ]
739
- @assert c == - 1
740
- dag[dr] = (c, dr)
741
- else
742
- c, dr = reach₌[idx]
743
- @assert c == 1
744
- end
745
- dr in stem_set && break
746
- var_to_diff[prev_r] = dr
747
- push! (updated_diff_vars, prev_r)
748
- prev_r = dr
745
+ # See example in the box above where D(D(D(z))) doesn't appear
746
+ # in the original system and needs to added, so we can alias to it.
747
+ # We do that here.
748
+ @assert prev_r != = - 1
749
+ prev_r = var_drivative_here! (prev_r)
750
+ r = nothing
749
751
else
750
752
prev_r = r
751
753
r = var_to_diff[r]
@@ -940,7 +942,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
940
942
end
941
943
942
944
complete_mm = reduce! (copy (echelon_mm), mm_orig, complete_ag, size (echelon_mm, 1 ))
943
- return ag, mm, complete_ag, complete_mm, updated_diff_vars
945
+ return ag, mm, complete_ag, complete_mm
944
946
end
945
947
946
948
function update_graph_neighbors! (graph, ag)
0 commit comments