@@ -497,9 +497,9 @@ count_nonzeros(a::AbstractArray) = count(!iszero, a)
497
497
# Here we have a guarantee that they won't, so we can make this identification
498
498
count_nonzeros (a:: SparseVector ) = nnz (a)
499
499
500
- # Linear variables are variables that only appear in linear equations with only
501
- # linear variables. Also, if a variable's any derivaitves is nonlinear, then all
502
- # of them are not linear variables.
500
+ # Linear variables are highest order differentiated variables that only appear
501
+ # in linear equations with only linear variables. Also, if a variable's any
502
+ # derivaitves is nonlinear, then all of them are not linear variables.
503
503
function find_linear_variables (graph, linear_equations, var_to_diff, irreducibles)
504
504
stack = Int[]
505
505
linear_variables = falses (length (var_to_diff))
@@ -541,8 +541,8 @@ function find_linear_variables(graph, linear_equations, var_to_diff, irreducible
541
541
islinear || continue
542
542
lv = extreme_var (var_to_diff, v)
543
543
oldlv = lv
544
- remove = false
545
- while true
544
+ remove = invview (var_to_diff)[v] != = nothing
545
+ while ! remove
546
546
for eq in 𝑑neighbors (graph, lv)
547
547
if ! (eq in linear_equations_set)
548
548
remove = true
@@ -788,37 +788,51 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
788
788
current_coeff_level[] = coeff, lv
789
789
extreme_var (var_to_diff, v, nothing , Val (false ), callback = add_alias!)
790
790
end
791
- max_lv > 0 || continue
791
+ len = length (level_to_var)
792
+ len > 1 || continue
792
793
793
794
set_v_zero! = let newag = newag
794
795
v -> newag[v] = 0
795
796
end
796
- len = length (level_to_var)
797
- for (i, v) in enumerate (level_to_var)
798
- _alias = get (ag, v, nothing )
799
-
800
- # if a chain starts to equal to zero, then all its descendants must
801
- # be zero and reducible
802
- if _alias != = nothing && iszero (_alias[1 ])
797
+ for (i, av) in enumerate (level_to_var)
798
+ has_zero = false
799
+ for v in neighbors (newinvag, av)
800
+ cv = get (ag, v, nothing )
801
+ cv === nothing && continue
802
+ c, v = cv
803
+ iszero (c) || continue
804
+ has_zero = true
805
+ # if a chain starts to equal to zero, then all its descendants
806
+ # must be zero and reducible
803
807
if i < len
804
808
# we have `x = 0`
805
809
v = level_to_var[i + 1 ]
806
810
extreme_var (var_to_diff, v, nothing , Val (false ), callback = set_v_zero!)
807
811
end
808
812
break
809
813
end
814
+ has_zero && break
810
815
811
816
# all non-highest order differentiated variables are reducible.
812
817
if i == len
813
818
# if an irreducible alias appears in only one equation, then
814
819
# it's actually not an alias, but a proper equation. E.g.
815
820
# D(D(phi)) = a
816
821
# D(phi) = sin(t)
817
- # `a` and `D(D(phi))` are not irreducible state.
818
- push! (irreducibles, v)
819
- for av in neighbors (newinvag, v)
820
- newag[av] = nothing
821
- push! (irreducibles, av)
822
+ # `a` and `D(D(phi))` are not irreducible state. Hence, we need
823
+ # to remove `av` from all alias graphs and mark those pairs
824
+ # irreducible.
825
+ push! (irreducibles, av)
826
+ for v in neighbors (newinvag, av)
827
+ newag[v] = nothing
828
+ push! (irreducibles, v)
829
+ end
830
+ for v in neighbors (invag, av)
831
+ newag[v] = nothing
832
+ push! (irreducibles, v)
833
+ end
834
+ if (cv = get (ag, av, nothing )) != = nothing && ! iszero (cv[2 ])
835
+ push! (irreducibles, cv[2 ])
822
836
end
823
837
end
824
838
end
0 commit comments