Skip to content

Commit b68ba20

Browse files
committed
Non-zero highest order differentiated variables are irreducible
1 parent bd143a5 commit b68ba20

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

src/systems/alias_elimination.jl

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,6 @@ count_nonzeros(a::SparseVector) = nnz(a)
501501
# linear variables. Also, if a variable's any derivaitves is nonlinear, then all
502502
# of them are not linear variables.
503503
function find_linear_variables(graph, linear_equations, var_to_diff, irreducibles)
504-
fullvars = Main._state[].fullvars
505504
stack = Int[]
506505
linear_variables = falses(length(var_to_diff))
507506
var_to_lineq = Dict{Int, BitSet}()
@@ -516,7 +515,6 @@ function find_linear_variables(graph, linear_equations, var_to_diff, irreducible
516515
eqs === nothing && continue
517516
for eq in eqs, v′ in 𝑠neighbors(graph, eq)
518517
if linear_variables[v′]
519-
@show v′, fullvars[v′]
520518
linear_variables[v′] = false
521519
push!(stack, v′)
522520
end
@@ -738,6 +736,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
738736
processed = falses(nvars)
739737
iag = InducedAliasGraph(ag, invag, var_to_diff)
740738
newag = AliasGraph(nvars)
739+
newinvag = SimpleDiGraph(nvars)
741740
irreducibles = BitSet()
742741
updated_diff_vars = Int[]
743742
for (v, dv) in enumerate(var_to_diff)
@@ -754,14 +753,16 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
754753
nlevels = length(level_to_var)
755754
current_coeff_level = Ref((0, 0))
756755
add_alias! = let current_coeff_level = current_coeff_level,
757-
level_to_var = level_to_var, newag = newag, processed = processed
756+
level_to_var = level_to_var, newag = newag, newinvag = newinvag,
757+
processed = processed
758758

759759
v -> begin
760760
coeff, level = current_coeff_level[]
761761
if level + 1 <= length(level_to_var)
762762
av = level_to_var[level + 1]
763763
if v != av # if the level_to_var isn't from the root branch
764764
newag[v] = coeff => av
765+
add_edge!(newinvag, av, v)
765766
end
766767
else
767768
@assert length(level_to_var) == level
@@ -792,44 +793,48 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
792793
set_v_zero! = let newag = newag
793794
v -> newag[v] = 0
794795
end
796+
len = length(level_to_var)
795797
for (i, v) in enumerate(level_to_var)
796798
_alias = get(ag, v, nothing)
797799

798800
# if a chain starts to equal to zero, then all its descendants must
799801
# be zero and reducible
800802
if _alias !== nothing && iszero(_alias[1])
801-
if i < length(level_to_var)
803+
if i < len
802804
# we have `x = 0`
803805
v = level_to_var[i + 1]
804806
extreme_var(var_to_diff, v, nothing, Val(false), callback = set_v_zero!)
805807
end
806808
break
807809
end
808810

809-
v_eqs = 𝑑neighbors(graph, v)
810-
# if an irreducible appears in only one equation, we need to make
811-
# sure that the other variables don't get eliminated
812-
if length(v_eqs) == 1
813-
eq = v_eqs[1]
814-
for av in 𝑠neighbors(graph, eq)
811+
# all non-highest order differentiated variables are reducible.
812+
if i == len
813+
# if an irreducible alias appears in only one equation, then
814+
# it's actually not an alias, but a proper equation. E.g.
815+
# D(D(phi)) = a
816+
# D(phi) = sin(t)
817+
# `a` and `D(D(phi))` are not irreducible state.
818+
push!(irreducibles, v)
819+
for av in neighbors(newinvag, v)
820+
newag[av] = nothing
815821
push!(irreducibles, av)
816822
end
817-
ag[v] = nothing
818823
end
819-
push!(irreducibles, v)
820824
end
821-
if nlevels < (new_nlevels = length(level_to_var))
822-
for i in (nlevels + 1):new_nlevels
825+
if nlevels < len
826+
for i in (nlevels + 1):len
823827
li = level_to_var[i]
824828
var_to_diff[level_to_var[i - 1]] = li
825829
push!(updated_diff_vars, level_to_var[i - 1])
826830
end
827831
end
828832
end
829-
for (v, (c, a)) in ag
833+
for (v, (c, a)) in newag
830834
a = a == 0 ? 0 : c * fullvars[a]
831835
@info "differential aliases" fullvars[v] => a
832836
end
837+
@show fullvars[collect(irreducibles)]
833838

834839
if !isempty(irreducibles)
835840
ag = newag

0 commit comments

Comments
 (0)