Skip to content

Commit 39fe5a8

Browse files
authored
Merge pull request #1701 from SciML/myb/alias
Simplify alias elimination logic
2 parents 0f17abf + 3ea322d commit 39fe5a8

File tree

2 files changed

+26
-32
lines changed

2 files changed

+26
-32
lines changed

src/systems/alias_elimination.jl

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -450,20 +450,28 @@ count_nonzeros(a::AbstractArray) = count(!iszero, a)
450450
# Here we have a guarantee that they won't, so we can make this identification
451451
count_nonzeros(a::SparseVector) = nnz(a)
452452

453-
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL, only_algebraic = true,
454-
irreducibles = ())
453+
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL, irreducibles = ())
455454
mm = copy(mm_orig)
456455
is_linear_equations = falses(size(AsSubMatrix(mm_orig), 1))
457456
for e in mm_orig.nzrows
458457
is_linear_equations[e] = true
459458
end
460459

461-
is_not_potential_state = isnothing.(var_to_diff)
460+
# If linear highest differentiated variables cannot be assigned to a pivot,
461+
# then we can set it to zero. We use `rank1` to track this.
462+
#
463+
# We only use alias graph to record reducible variables. We use `rank2` to
464+
# track this.
465+
#
466+
# For all the other variables, we can update the original system with
467+
# Bareiss'ed coefficients as Gaussian elimination is nullspace perserving
468+
# and we are only working on linear homogeneous subsystem.
469+
is_linear_variables = isnothing.(var_to_diff)
470+
is_reducible = trues(length(var_to_diff))
462471
for v in irreducibles
463-
is_not_potential_state[v] = false
472+
is_linear_variables[v] = false
473+
is_reducible[v] = false
464474
end
465-
is_linear_variables = only_algebraic ? copy(is_not_potential_state) :
466-
is_not_potential_state
467475
for i in 𝑠vertices(graph)
468476
is_linear_equations[i] && continue
469477
for j in 𝑠neighbors(graph, i)
@@ -481,12 +489,10 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL, only_algebr
481489
r !== nothing && return r
482490
rank1 = k - 1
483491
end
484-
if only_algebraic
485-
if rank2 === nothing
486-
r = find_masked_pivot(is_not_potential_state, M, k)
487-
r !== nothing && return r
488-
rank2 = k - 1
489-
end
492+
if rank2 === nothing
493+
r = find_masked_pivot(is_reducible, M, k)
494+
r !== nothing && return r
495+
rank2 = k - 1
490496
end
491497
# TODO: It would be better to sort the variables by
492498
# derivative order here to enable more elimination
@@ -524,25 +530,9 @@ function lss(mm, pivots, ag)
524530
end
525531
end
526532

527-
function simple_aliases!(ag, graph, var_to_diff, mm_orig, only_algebraic, irreducibles = ())
528-
# Let `m = the number of linear equations` and `n = the number of
529-
# variables`.
530-
#
531-
# `do_bareiss` conceptually gives us this system:
532-
# rank1 | [ M₁₁ M₁₂ | M₁₃ ] [v₁] = [0]
533-
# rank2 | [ 0 M₂₂ | M₂₃ ] P [v₂] = [0]
534-
# -------------------|-------------------
535-
# [ 0 0 | 0 ] [v₃] = [0]
536-
537-
# Where `v₁` are the purely linear algebraic variables (i.e. those that only
538-
# appear in linear algebraic equations), `v₂` are the variables that may be
539-
# potentially solved by the linear system, and `v₃` are the variables that
540-
# contribute to the equations, but are not solved by the linear system. Note
541-
# that the complete system may be larger than the linear subsystem and
542-
# include variables that do not appear here.
533+
function simple_aliases!(ag, graph, var_to_diff, mm_orig, irreducibles = ())
543534
mm, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss!(graph, var_to_diff,
544535
mm_orig,
545-
only_algebraic,
546536
irreducibles)
547537

548538
# Step 2: Simplify the system using the Bareiss factorization
@@ -585,7 +575,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
585575
#
586576
nvars = ndsts(graph)
587577
ag = AliasGraph(nvars)
588-
mm = simple_aliases!(ag, graph, var_to_diff, mm_orig, false)
578+
mm = simple_aliases!(ag, graph, var_to_diff, mm_orig)
589579

590580
# Step 3: Handle differentiated variables
591581
# At this point, `var_to_diff` and `ag` form a tree structure like the
@@ -694,7 +684,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
694684
if !isempty(irreducibles)
695685
ag = newag
696686
mm_orig2 = isempty(ag) ? mm_orig : reduce!(copy(mm_orig), ag)
697-
mm = simple_aliases!(ag, graph, var_to_diff, mm_orig2, true, irreducibles)
687+
mm = simple_aliases!(ag, graph, var_to_diff, mm_orig2, irreducibles)
698688
end
699689

700690
# for (v, (c, a)) in ag

test/components.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ function check_contract(sys)
1919
end
2020

2121
function check_rc_sol(sol)
22+
rpi = sol[rc_model.resistor.p.i]
23+
@test any(!isequal(rpi[1]), rpi) # test that we don't have a constant system
2224
@test sol[rc_model.resistor.p.i] == sol[resistor.p.i] == sol[capacitor.p.i]
2325
@test sol[rc_model.resistor.n.i] == sol[resistor.n.i] == -sol[capacitor.p.i]
2426
@test sol[rc_model.capacitor.n.i] == sol[capacitor.n.i] == -sol[capacitor.p.i]
@@ -31,8 +33,9 @@ end
3133
include("../examples/rc_model.jl")
3234

3335
@test ModelingToolkit.n_extra_equations(capacitor) == 2
34-
@test length(equations(structural_simplify(rc_model, allow_parameter = false))) > 1
36+
@test length(equations(structural_simplify(rc_model, allow_parameter = false))) == 2
3537
sys = structural_simplify(rc_model)
38+
@test length(equations(sys)) == 1
3639
check_contract(sys)
3740
@test !isempty(ModelingToolkit.defaults(sys))
3841
u0 = [capacitor.v => 0.0
@@ -141,6 +144,7 @@ sol = solve(prob, Tsit5())
141144

142145
include("../examples/serial_inductor.jl")
143146
sys = structural_simplify(ll_model)
147+
@test length(equations(sys)) == 2
144148
check_contract(sys)
145149
u0 = states(sys) .=> 0
146150
@test_nowarn ODEProblem(sys, u0, (0, 10.0))

0 commit comments

Comments
 (0)