@@ -54,23 +54,16 @@ function alias_elimination!(state::TearingState)
54
54
end
55
55
56
56
subs = Dict ()
57
+ obs = Equation[]
57
58
# If we encounter y = -D(x), then we need to expand the derivative when
58
59
# D(y) appears in the equation, so that D(-D(x)) becomes -D(D(x)).
59
60
to_expand = Int[]
60
61
diff_to_var = invview (var_to_diff)
61
- # TODO /FIXME : this also needs to be computed recursively because we need to
62
- # follow the alias graph like `a => b => c` and make sure that the final
63
- # graph always contains the destination.
64
- extra_eqs = Equation[]
65
- extra_vars = BitSet ()
66
62
for (v, (coeff, alias)) in pairs (ag)
67
- if iszero (alias) || ! (isempty (𝑑neighbors (graph, alias)) && isempty (𝑑neighbors (graph, v)))
68
- subs[fullvars[v]] = iszero (coeff) ? 0 : coeff * fullvars[alias]
69
- else
70
- push! (extra_eqs, 0 ~ coeff * fullvars[alias] - fullvars[v])
71
- push! (extra_vars, v)
72
- # @show fullvars[v] => fullvars[alias]
73
- end
63
+ lhs = fullvars[v]
64
+ rhs = iszero (coeff) ? 0 : coeff * fullvars[alias]
65
+ subs[lhs] = rhs
66
+ v != alias && push! (obs, lhs ~ rhs)
74
67
if coeff == - 1
75
68
# if `alias` is like -D(x)
76
69
diff_to_var[alias] === nothing && continue
@@ -101,7 +94,7 @@ function alias_elimination!(state::TearingState)
101
94
idx = 0
102
95
cursor = 1
103
96
ndels = length (dels)
104
- for (i, e) in enumerate (old_to_new)
97
+ for i in eachindex (old_to_new)
105
98
if cursor <= ndels && i == dels[cursor]
106
99
cursor += 1
107
100
old_to_new[i] = - 1
@@ -111,7 +104,9 @@ function alias_elimination!(state::TearingState)
111
104
old_to_new[i] = idx
112
105
end
113
106
107
+ lineqs = BitSet (old_to_new[e] for e in mm. nzrows)
114
108
for (ieq, eq) in enumerate (eqs)
109
+ ieq in lineqs && continue
115
110
eqs[ieq] = substitute (eq, subs)
116
111
end
117
112
@@ -131,7 +126,7 @@ function alias_elimination!(state::TearingState)
131
126
sys = state. sys
132
127
@set! sys. eqs = eqs
133
128
@set! sys. states = newstates
134
- @set! sys. observed = [observed (sys); [lhs ~ rhs for (lhs, rhs) in pairs (subs)] ]
129
+ @set! sys. observed = [observed (sys); obs ]
135
130
return invalidate_cache! (sys)
136
131
end
137
132
@@ -212,7 +207,8 @@ function Base.getindex(ag::AliasGraph, i::Integer)
212
207
coeff, var = (sign (r), abs (r))
213
208
nc = coeff
214
209
av = var
215
- if var in keys (ag)
210
+ # We support `x -> -x` as an alias.
211
+ if var != i && var in keys (ag)
216
212
# Amortized lookup. Check if since we last looked this up, our alias was
217
213
# itself aliased. If so, just adjust the alias table.
218
214
ac, av = ag[var]
248
244
249
245
function Base. setindex! (ag:: AliasGraph , p:: Pair{Int, Int} , i:: Integer )
250
246
(c, v) = p
247
+ if c == 0 || v == 0
248
+ ag[i] = 0
249
+ return p
250
+ end
251
251
@assert v != 0 && c in (- 1 , 1 )
252
252
if ag. aliasto[i] === nothing
253
253
push! (ag. eliminated, i)
@@ -280,22 +280,27 @@ function reduce!(mm::SparseMatrixCLIL, ag::AliasGraph)
280
280
c = rs[j]
281
281
_alias = get (ag, c, nothing )
282
282
if _alias != = nothing
283
- push! (dels, j)
284
283
coeff, alias = _alias
285
- iszero (coeff) && (j += 1 ; continue )
286
- inc = coeff * rvals[j]
287
- i = searchsortedfirst (rs, alias)
288
- if i > length (rs) || rs[i] != alias
289
- # if we add a variable to what we already visited, make sure
290
- # to bump the cursor.
291
- j += i <= j
292
- for (i, e) in enumerate (dels)
293
- e >= i && (dels[i] += 1 )
294
- end
295
- insert! (rs, i, alias)
296
- insert! (rvals, i, inc)
284
+ if alias == c
285
+ i = searchsortedfirst (rs, alias)
286
+ rvals[i] *= coeff
297
287
else
298
- rvals[i] += inc
288
+ push! (dels, j)
289
+ iszero (coeff) && (j += 1 ; continue )
290
+ inc = coeff * rvals[j]
291
+ i = searchsortedfirst (rs, alias)
292
+ if i > length (rs) || rs[i] != alias
293
+ # if we add a variable to what we already visited, make sure
294
+ # to bump the cursor.
295
+ j += i <= j
296
+ for (i, e) in enumerate (dels)
297
+ e >= i && (dels[i] += 1 )
298
+ end
299
+ insert! (rs, i, alias)
300
+ insert! (rvals, i, inc)
301
+ else
302
+ rvals[i] += inc
303
+ end
299
304
end
300
305
end
301
306
j += 1
@@ -651,6 +656,7 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, irreducibles = ())
651
656
ag[v] = 0
652
657
end
653
658
659
+ echelon_mm = copy (mm)
654
660
lss! = lss (mm, pivots, ag)
655
661
# Step 2.1: Go backwards, collecting eliminated variables and substituting
656
662
# alias as we go.
@@ -677,7 +683,7 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, irreducibles = ())
677
683
reduced && while any (lss!, 1 : rank2)
678
684
end
679
685
680
- return mm
686
+ return mm, echelon_mm
681
687
end
682
688
683
689
function mark_processed! (processed, var_to_diff, v)
@@ -695,7 +701,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
695
701
#
696
702
nvars = ndsts (graph)
697
703
ag = AliasGraph (nvars)
698
- mm = simple_aliases! (ag, graph, var_to_diff, mm_orig)
704
+ mm, echelon_mm = simple_aliases! (ag, graph, var_to_diff, mm_orig)
699
705
state = Main. _state[]
700
706
fullvars = state. fullvars
701
707
for (v, (c, a)) in ag
@@ -718,14 +724,14 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
718
724
#
719
725
# where `-->` is an edge in `var_to_diff`, `⇒` is an edge in `ag`, and the
720
726
# part in the box are purely conceptual, i.e. `D(D(D(z)))` doesn't appear in
721
- # the system.
727
+ # the system. We call the variables in the box "virtual" variables.
722
728
#
723
729
# To finish the algorithm, we backtrack to the root differentiation chain.
724
730
# If the variable already exists in the chain, then we alias them
725
731
# (e.g. `x_t ⇒ D(D(z))`), else, we substitute and update `var_to_diff`.
726
732
#
727
733
# Note that since we always prefer the higher differentiated variable and
728
- # with a tie breaking strategy. The root variable (in this case `z`) is
734
+ # with a tie breaking strategy, the root variable (in this case `z`) is
729
735
# always uniquely determined. Thus, the result is well-defined.
730
736
diff_to_var = invview (var_to_diff)
731
737
invag = SimpleDiGraph (nvars)
@@ -735,10 +741,11 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
735
741
end
736
742
processed = falses (nvars)
737
743
iag = InducedAliasGraph (ag, invag, var_to_diff)
738
- newag = AliasGraph (nvars)
744
+ dag = AliasGraph (nvars) # alias graph for differentiated variables
739
745
newinvag = SimpleDiGraph (nvars)
740
- irreducibles = BitSet ()
746
+ removed_aliases = BitSet ()
741
747
updated_diff_vars = Int[]
748
+ irreducibles = Int[]
742
749
for (v, dv) in enumerate (var_to_diff)
743
750
processed[v] && continue
744
751
(dv === nothing && diff_to_var[v] === nothing ) && continue
@@ -753,19 +760,22 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
753
760
nlevels = length (level_to_var)
754
761
current_coeff_level = Ref ((0 , 0 ))
755
762
add_alias! = let current_coeff_level = current_coeff_level,
756
- level_to_var = level_to_var, newag = newag , newinvag = newinvag,
763
+ level_to_var = level_to_var, dag = dag , newinvag = newinvag,
757
764
processed = processed
758
765
759
766
v -> begin
760
767
coeff, level = current_coeff_level[]
761
768
if level + 1 <= length (level_to_var)
762
769
av = level_to_var[level + 1 ]
763
770
if v != av # if the level_to_var isn't from the root branch
764
- newag [v] = coeff => av
771
+ dag [v] = coeff => av
765
772
add_edge! (newinvag, av, v)
766
773
end
767
774
else
768
775
@assert length (level_to_var) == level
776
+ if coeff != 1
777
+ dag[v] = coeff => v
778
+ end
769
779
push! (level_to_var, v)
770
780
end
771
781
mark_processed! (processed, var_to_diff, v)
@@ -788,54 +798,38 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
788
798
current_coeff_level[] = coeff, lv
789
799
extreme_var (var_to_diff, v, nothing , Val (false ), callback = add_alias!)
790
800
end
791
- len = length (level_to_var)
792
- len > 1 || continue
793
801
794
- set_v_zero! = let newag = newag
795
- v -> newag[v] = 0
802
+ @show processed
803
+ len = length (level_to_var)
804
+ set_v_zero! = let dag = dag
805
+ v -> dag[v] = 0
796
806
end
807
+ zero_av_idx = 0
797
808
for (i, av) in enumerate (level_to_var)
798
- has_zero = false
809
+ has_zero = iszero (get (ag, av, (1 , 0 ))[1 ])
810
+ push! (removed_aliases, av)
799
811
for v in neighbors (newinvag, av)
800
- cv = get (ag, v, nothing )
801
- cv === nothing && continue
802
- c, v = cv
803
- iszero (c) || continue
804
- has_zero = true
805
- # if a chain starts to equal to zero, then all its descendants
806
- # must be zero and reducible
807
- if i < len
808
- # we have `x = 0`
809
- v = level_to_var[i + 1 ]
810
- extreme_var (var_to_diff, v, nothing , Val (false ), callback = set_v_zero!)
811
- end
812
- break
812
+ has_zero = has_zero || iszero (get (ag, v, (1 , 0 ))[1 ])
813
+ push! (removed_aliases, v)
813
814
end
814
- has_zero && break
815
-
816
- # all non-highest order differentiated variables are reducible.
817
- if i == len
818
- # if an irreducible alias appears in only one equation, then
819
- # it's actually not an alias, but a proper equation. E.g.
820
- # D(D(phi)) = a
821
- # D(phi) = sin(t)
822
- # `a` and `D(D(phi))` are not irreducible state. Hence, we need
823
- # to remove `av` from all alias graphs and mark those pairs
824
- # irreducible.
825
- push! (irreducibles, av)
826
- for v in neighbors (newinvag, av)
827
- newag[v] = nothing
828
- push! (irreducibles, v)
829
- end
830
- for v in neighbors (invag, av)
831
- newag[v] = nothing
832
- push! (irreducibles, v)
833
- end
834
- if (cv = get (ag, av, nothing )) != = nothing && ! iszero (cv[2 ])
835
- push! (irreducibles, cv[2 ])
836
- end
815
+ if zero_av_idx == 0 && has_zero
816
+ zero_av_idx = i
817
+ end
818
+ end
819
+ # If a chain starts to equal to zero, then all its derivatives must be
820
+ # zero. Irreducible variables are highest differentiated variables (with
821
+ # order >= 1) that are not zero.
822
+ if zero_av_idx > 0
823
+ extreme_var (var_to_diff, level_to_var[zero_av_idx], nothing , Val (false ), callback = set_v_zero!)
824
+ if zero_av_idx > 2
825
+ @warn " 1"
826
+ push! (irreducibles, level_to_var[zero_av_idx - 1 ])
837
827
end
828
+ elseif len >= 2
829
+ @warn " 2"
830
+ push! (irreducibles, level_to_var[len])
838
831
end
832
+ # Handle virtual variables
839
833
if nlevels < len
840
834
for i in (nlevels + 1 ): len
841
835
li = level_to_var[i]
@@ -844,20 +838,42 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
844
838
end
845
839
end
846
840
end
847
- for (v, (c, a)) in newag
848
- a = a == 0 ? 0 : c * fullvars[a]
849
- @info " differential aliases" fullvars[v] => a
850
- end
851
- @show fullvars[collect (irreducibles)]
852
841
853
- if ! isempty (irreducibles)
854
- ag = newag
855
- for k in keys (ag)
856
- push! (irreducibles, k)
842
+ # Merge dag and ag
843
+ freshag = AliasGraph (nvars)
844
+ @show irreducibles
845
+ @show dag
846
+ for (v, (c, a)) in dag
847
+ # TODO : make sure that `irreducibles` are
848
+ # D(x) ~ D(y) cannot be removed if x and y are not aliases
849
+ if v != a && a in irreducibles
850
+ push! (removed_aliases, v)
851
+ @goto NEXT_ITER
852
+ elseif v != a && ! iszero (a)
853
+ vv = v
854
+ aa = a
855
+ while true
856
+ vv′ = vv
857
+ vv = diff_to_var[vv]
858
+ vv === nothing && break
859
+ if ! (haskey (dag, vv) && dag[vv][2 ] == diff_to_var[aa])
860
+ push! (removed_aliases, vv′)
861
+ @goto NEXT_ITER
862
+ end
863
+ end
857
864
end
858
- mm_orig2 = isempty (ag) ? mm_orig : reduce! (copy (mm_orig), ag)
859
- mm = simple_aliases! (ag, graph, var_to_diff, mm_orig2, irreducibles)
865
+ freshag[v] = c => a
866
+ @label NEXT_ITER
867
+ end
868
+ for (v, (c, a)) in ag
869
+ v in removed_aliases && continue
870
+ freshag[v] = c => a
871
+ end
872
+ if freshag != ag
873
+ ag = freshag
874
+ mm = reduce! (copy (echelon_mm), ag)
860
875
end
876
+ @info " " echelon_mm mm
861
877
862
878
for (v, (c, a)) in ag
863
879
va = iszero (a) ? a : fullvars[a]
@@ -869,7 +885,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
869
885
set_neighbors! (graph, e, mm. row_cols[ei])
870
886
end
871
887
872
- # because of `irreducibles`, `mm` cannot always be trusted.
873
888
return ag, mm, updated_diff_vars
874
889
end
875
890
0 commit comments