@@ -11,17 +11,15 @@ function alias_eliminate_graph!(state::TransformationState; kwargs...)
11
11
end
12
12
13
13
@unpack graph, var_to_diff, solvable_graph = state. structure
14
- ag, mm, complete_ag, complete_mm, updated_diff_vars = alias_eliminate_graph! (complete (graph),
15
- complete (var_to_diff),
16
- mm)
14
+ ag, mm, complete_ag, complete_mm = alias_eliminate_graph! (state, mm)
17
15
if solvable_graph != = nothing
18
16
for (ei, e) in enumerate (mm. nzrows)
19
17
set_neighbors! (solvable_graph, e, mm. row_cols[ei])
20
18
end
21
19
update_graph_neighbors! (solvable_graph, ag)
22
20
end
23
21
24
- return ag, mm, complete_ag, complete_mm, updated_diff_vars
22
+ return ag, mm, complete_ag, complete_mm
25
23
end
26
24
27
25
# For debug purposes
@@ -51,23 +49,12 @@ function alias_elimination!(state::TearingState; kwargs...)
51
49
sys = state. sys
52
50
complete! (state. structure)
53
51
graph_orig = copy (state. structure. graph)
54
- ag, mm, complete_ag, complete_mm, updated_diff_vars = alias_eliminate_graph! (state;
55
- kwargs... )
52
+ ag, mm, complete_ag, complete_mm = alias_eliminate_graph! (state; kwargs... )
56
53
isempty (ag) && return sys, ag
57
54
58
55
fullvars = state. fullvars
59
56
@unpack var_to_diff, graph, solvable_graph = state. structure
60
57
61
- if ! isempty (updated_diff_vars)
62
- has_iv (sys) ||
63
- error (InvalidSystemException (" The system has no independent variable!" ))
64
- D = Differential (get_iv (sys))
65
- for v in updated_diff_vars
66
- dv = var_to_diff[v]
67
- fullvars[dv] = D (fullvars[v])
68
- end
69
- end
70
-
71
58
subs = Dict ()
72
59
obs = Equation[]
73
60
# If we encounter y = -D(x), then we need to expand the derivative when
@@ -328,6 +315,9 @@ function Base.setindex!(ag::AliasGraph, ::Nothing, i::Integer)
328
315
end
329
316
function Base. setindex! (ag:: AliasGraph , v:: Integer , i:: Integer )
330
317
@assert v == 0
318
+ if i > length (ag. aliasto)
319
+ resize! (ag. aliasto, i)
320
+ end
331
321
if ag. aliasto[i] === nothing
332
322
push! (ag. eliminated, i)
333
323
end
@@ -343,6 +333,9 @@ function Base.setindex!(ag::AliasGraph, p::Union{Pair{Int, Int}, Tuple{Int, Int}
343
333
return p
344
334
end
345
335
@assert v != 0 && c in (- 1 , 1 )
336
+ if i > length (ag. aliasto)
337
+ resize! (ag. aliasto, i)
338
+ end
346
339
if ag. aliasto[i] === nothing
347
340
push! (ag. eliminated, i)
348
341
end
@@ -379,6 +372,7 @@ function Graphs.add_edge!(g::WeightedGraph, u, v, w)
379
372
r && (g. dict[canonicalize (u, v)] = w)
380
373
r
381
374
end
375
+ Graphs. add_vertex! (g:: WeightedGraph ) = add_vertex! (g. graph)
382
376
Graphs. has_edge (g:: WeightedGraph , u, v) = has_edge (g. graph, u, v)
383
377
Graphs. ne (g:: WeightedGraph ) = ne (g. graph)
384
378
Graphs. nv (g:: WeightedGraph ) = nv (g. graph)
@@ -656,7 +650,20 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig)
656
650
return mm, echelon_mm
657
651
end
658
652
659
- function alias_eliminate_graph! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL )
653
+ function var_derivative_here! (state, processed, g, eqg, dls, diff_var)
654
+ newvar = var_derivative! (state, diff_var)
655
+ @assert newvar == length (processed) + 1
656
+ push! (processed, true )
657
+ add_vertex! (g)
658
+ add_vertex! (eqg)
659
+ add_edge! (g, diff_var, newvar)
660
+ add_edge! (g, newvar, diff_var)
661
+ push! (dls. dists, typemax (Int))
662
+ return newvar
663
+ end
664
+
665
+ function alias_eliminate_graph! (state:: TransformationState , mm_orig:: SparseMatrixCLIL )
666
+ @unpack graph, var_to_diff = state. structure
660
667
# Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
661
668
# subsystem of the system we're interested in.
662
669
#
@@ -689,11 +696,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
689
696
# with a tie breaking strategy, the root variable (in this case `z`) is
690
697
# always uniquely determined. Thus, the result is well-defined.
691
698
dag = AliasGraph (nvars) # alias graph for differentiated variables
692
- updated_diff_vars = Int[]
693
699
diff_to_var = invview (var_to_diff)
694
700
processed = falses (nvars)
695
701
g, eqg, zero_vars = equality_diff_graph (ag, var_to_diff)
696
702
dls = DiffLevelState (g, var_to_diff)
703
+ original_nvars = length (var_to_diff)
704
+
697
705
is_diff_edge = let var_to_diff = var_to_diff
698
706
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
699
707
end
@@ -709,10 +717,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
709
717
for _ in 1 : 10_000 # just to make sure that we don't stuck in an infinite loop
710
718
reach₌ = Pair{Int, Int}[]
711
719
# `r` is aliased to its equality aliases
712
- r === nothing || for n in neighbors (eqg, r)
713
- (n == r || is_diff_edge (r, n)) && continue
714
- c = get_weight (eqg, r, n)
715
- push! (reach₌, c => n)
720
+ if r != = nothing
721
+ for n in neighbors (eqg, r)
722
+ (n == r || is_diff_edge (r, n)) && continue
723
+ c = get_weight (eqg, r, n)
724
+ push! (reach₌, c => n)
725
+ end
716
726
end
717
727
# `r` is aliased to its previous differentiation level's aliases'
718
728
# derivative
@@ -733,19 +743,15 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
733
743
end
734
744
if r === nothing
735
745
isempty (reach₌) && break
736
- idx = findfirst (x -> x[1 ] == 1 , reach₌)
737
- if idx === nothing
738
- c, dr = reach₌[1 ]
739
- @assert c == - 1
740
- dag[dr] = (c, dr)
741
- else
742
- c, dr = reach₌[idx]
743
- @assert c == 1
746
+ let stem_set = stem_set
747
+ any (x -> x[2 ] in stem_set, reach₌) && break
744
748
end
745
- dr in stem_set && break
746
- var_to_diff[prev_r] = dr
747
- push! (updated_diff_vars, prev_r)
748
- prev_r = dr
749
+ # See example in the box above where D(D(D(z))) doesn't appear
750
+ # in the original system and needs to added, so we can alias to it.
751
+ # We do that here.
752
+ @assert prev_r != = - 1
753
+ prev_r = var_derivative_here! (state, processed, g, eqg, dls, prev_r)
754
+ r = nothing
749
755
else
750
756
prev_r = r
751
757
r = var_to_diff[r]
@@ -806,15 +812,18 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
806
812
for i in 1 : (length (stem) - 1 )
807
813
r = stem[i]
808
814
for dr in @view stem[(i + 1 ): end ]
815
+ # We cannot reduce newly introduced variables like `D(D(D(z)))`
816
+ # in the example box above.
817
+ dr > original_nvars && continue
809
818
if has_edge (eqg, r, dr)
810
819
c = get_weight (eqg, r, dr)
811
820
dag[dr] = c => r
812
821
end
813
822
end
814
823
end
815
824
# If a non-differentiated variable equals to 0, then we can eliminate
816
- # the whole differentiation chain. Otherwise, we can have to keep the
817
- # lowest differentiate variable in the differentiation chain.
825
+ # the whole differentiation chain. Otherwise, we will still have to keep
826
+ # the lowest differentiated variable in the differentiation chain.
818
827
# E.g.
819
828
# ```
820
829
# D(x) ~ 0
@@ -854,6 +863,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
854
863
end
855
864
# reducible after v
856
865
while (v = var_to_diff[v]) != = nothing
866
+ complete_ag[v] = 0
857
867
dag[v] = 0
858
868
end
859
869
end
@@ -900,6 +910,8 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
900
910
merged_ag[v] = c => a
901
911
end
902
912
ag = merged_ag
913
+ @set! echelon_mm. ncols = length (var_to_diff)
914
+ @set! mm_orig. ncols = length (var_to_diff)
903
915
mm = reduce! (copy (echelon_mm), mm_orig, ag, size (echelon_mm, 1 ))
904
916
905
917
# Step 5: Reflect our update decisions back into the graph, and make sure
@@ -940,7 +952,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
940
952
end
941
953
942
954
complete_mm = reduce! (copy (echelon_mm), mm_orig, complete_ag, size (echelon_mm, 1 ))
943
- return ag, mm, complete_ag, complete_mm, updated_diff_vars
955
+ return ag, mm, complete_ag, complete_mm
944
956
end
945
957
946
958
function update_graph_neighbors! (graph, ag)
0 commit comments