Skip to content

Commit d63043a

Browse files
committed
Do not eliminate the newly introduced variable & resize matrix
1 parent 5d54a03 commit d63043a

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

src/systems/alias_elimination.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,18 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig)
650650
return mm, echelon_mm
651651
end
652652

653+
function var_derivative_here!(state, processed, g, eqg, dls, diff_var)
654+
newvar = var_derivative!(state, diff_var)
655+
@assert newvar == length(processed) + 1
656+
push!(processed, true)
657+
add_vertex!(g)
658+
add_vertex!(eqg)
659+
add_edge!(g, diff_var, newvar)
660+
add_edge!(g, newvar, diff_var)
661+
push!(dls.dists, typemax(Int))
662+
return newvar
663+
end
664+
653665
function alias_eliminate_graph!(state::TransformationState, mm_orig::SparseMatrixCLIL)
654666
@unpack graph, var_to_diff = state.structure
655667
# Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
@@ -688,18 +700,7 @@ function alias_eliminate_graph!(state::TransformationState, mm_orig::SparseMatri
688700
processed = falses(nvars)
689701
g, eqg, zero_vars = equality_diff_graph(ag, var_to_diff)
690702
dls = DiffLevelState(g, var_to_diff)
691-
692-
function var_drivative_here!(diff_var)
693-
newvar = var_derivative!(state, diff_var)
694-
@assert newvar == length(processed)+1
695-
push!(processed, true)
696-
add_vertex!(g)
697-
add_vertex!(eqg)
698-
add_edge!(g, diff_var, newvar)
699-
add_edge!(g, newvar, diff_var)
700-
push!(dls.dists, typemax(Int))
701-
return newvar
702-
end
703+
original_nvars = length(var_to_diff)
703704

704705
is_diff_edge = let var_to_diff = var_to_diff
705706
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
@@ -746,7 +747,7 @@ function alias_eliminate_graph!(state::TransformationState, mm_orig::SparseMatri
746747
# in the original system and needs to added, so we can alias to it.
747748
# We do that here.
748749
@assert prev_r !== -1
749-
prev_r = var_drivative_here!(prev_r)
750+
prev_r = var_derivative_here!(state, processed, g, eqg, dls, prev_r)
750751
r = nothing
751752
else
752753
prev_r = r
@@ -808,15 +809,18 @@ function alias_eliminate_graph!(state::TransformationState, mm_orig::SparseMatri
808809
for i in 1:(length(stem) - 1)
809810
r = stem[i]
810811
for dr in @view stem[(i + 1):end]
812+
# We cannot reduce newly introduced variables like `D(D(D(z)))`
813+
# in the example box above.
814+
dr > original_nvars && continue
811815
if has_edge(eqg, r, dr)
812816
c = get_weight(eqg, r, dr)
813817
dag[dr] = c => r
814818
end
815819
end
816820
end
817821
# If a non-differentiated variable equals to 0, then we can eliminate
818-
# the whole differentiation chain. Otherwise, we can have to keep the
819-
# lowest differentiate variable in the differentiation chain.
822+
# the whole differentiation chain. Otherwise, we will still have to keep
823+
# the lowest differentiated variable in the differentiation chain.
820824
# E.g.
821825
# ```
822826
# D(x) ~ 0
@@ -856,6 +860,7 @@ function alias_eliminate_graph!(state::TransformationState, mm_orig::SparseMatri
856860
end
857861
# reducible after v
858862
while (v = var_to_diff[v]) !== nothing
863+
complete_ag[v] = 0
859864
dag[v] = 0
860865
end
861866
end
@@ -902,7 +907,8 @@ function alias_eliminate_graph!(state::TransformationState, mm_orig::SparseMatri
902907
merged_ag[v] = c => a
903908
end
904909
ag = merged_ag
905-
mm = reduce!(copy(echelon_mm), mm_orig, ag, size(echelon_mm, 1))
910+
echelon_mm = resize_cols(echelon_mm, length(var_to_diff))
911+
mm = reduce!(echelon_mm, mm_orig, ag, size(echelon_mm, 1))
906912

907913
# Step 5: Reflect our update decisions back into the graph, and make sure
908914
# that the RHS of observable variables are defined.

src/systems/sparsematrixclil.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ function SparseMatrixCLIL(mm::AbstractMatrix)
4141
SparseMatrixCLIL(nrows, ncols, Int[1:length(row_cols);], row_cols, row_vals)
4242
end
4343

44+
function resize_cols(mm::SparseMatrixCLIL, nc)
45+
@set! mm.ncols = nc
46+
copy(mm)
47+
end
48+
4449
struct CLILVector{T, Ti} <: AbstractSparseVector{T, Ti}
4550
vec::SparseVector{T, Ti}
4651
end

0 commit comments

Comments
 (0)