Skip to content

Commit 2140705

Browse files
committed
Handle zero variables and fix reduce!
1 parent 204b6d4 commit 2140705

File tree

2 files changed

+119
-66
lines changed

2 files changed

+119
-66
lines changed

src/systems/abstractsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,7 @@ function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...)
955955
state = inputs_to_parameters!(state)
956956
sys = state.sys
957957
check_consistency(state)
958+
find_solvables!(state; kwargs...)
958959
sys = dummy_derivative(sys, state)
959960
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
960961
@set! sys.observed = topsort_equations(observed(sys), fullstates)

src/systems/alias_elimination.jl

Lines changed: 118 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ function alias_elimination(sys; debug = false)
5050
end
5151
end
5252

53+
subs = Dict()
54+
for (v, (coeff, alias)) in pairs(ag)
55+
subs[fullvars[v]] = iszero(coeff) ? 0 : coeff * fullvars[alias]
56+
end
57+
5358
dels = Int[]
5459
eqs = collect(equations(state))
5560
for (ei, e) in enumerate(mm.nzrows)
@@ -67,10 +72,6 @@ function alias_elimination(sys; debug = false)
6772
end
6873
deleteat!(eqs, sort!(dels))
6974

70-
subs = Dict()
71-
for (v, (coeff, alias)) in pairs(ag)
72-
subs[fullvars[v]] = iszero(coeff) ? 0 : coeff * fullvars[alias]
73-
end
7475
for (ieq, eq) in enumerate(eqs)
7576
eqs[ieq] = substitute(eq, subs)
7677
end
@@ -229,22 +230,29 @@ end
229230
function reduce!(mm::SparseMatrixCLIL, ag::AliasGraph)
230231
dels = Int[]
231232
for (i, rs) in enumerate(mm.row_cols)
233+
p = i == 7
232234
rvals = mm.row_vals[i]
233-
for (j, c) in enumerate(rs)
235+
j = 1
236+
while j <= length(rs)
237+
c = rs[j]
234238
_alias = get(ag, c, nothing)
235239
if _alias !== nothing
236240
push!(dels, j)
237241
coeff, alias = _alias
238-
iszero(coeff) && continue
242+
iszero(coeff) && (j += 1; continue)
239243
inc = coeff * rvals[j]
240244
i = searchsortedfirst(rs, alias)
241-
if i > length(rvals)
242-
push!(rs, alias)
243-
push!(rvals, inc)
245+
if i > length(rs) || rs[i] != alias
246+
if i <= j
247+
j += 1
248+
end
249+
insert!(rs, i, alias)
250+
insert!(rvals, i, inc)
244251
else
245252
rvals[i] += inc
246253
end
247254
end
255+
j += 1
248256
end
249257
deleteat!(rs, dels)
250258
deleteat!(rvals, dels)
@@ -304,8 +312,9 @@ function Base.iterate(it::IAGNeighbors, state = nothing)
304312
end
305313
else
306314
used_ag = true
307-
if (_n = get(ag, v, nothing)) !== nothing
308-
n = _n[2]
315+
# We don't care about the alising value because we only use this to
316+
# find the root of the tree.
317+
if (_n = get(ag, v, nothing)) !== nothing && (n = _n[2]) > 0
309318
if !visited[n]
310319
n, lv = extreme_var(var_to_diff, n, level)
311320
extreme_var(var_to_diff, n, nothing, Val(false), callback = callback!)
@@ -366,7 +375,7 @@ function Base.iterate(it::StatefulAliasBFS, queue = (eltype(it)[(1, 0, it.t)]))
366375
coeff, lv, t = popfirst!(queue)
367376
nextlv = lv + 1
368377
for (coeff′, c) in children(t)
369-
# TODO: maybe fix the children iterator instead.
378+
# FIXME: use the visited cache!
370379
# A "cycle" might occur when we have `x ~ D(x)`.
371380
nodevalue(c) == t.root && continue
372381
# -1 <= coeff <= 1
@@ -383,7 +392,7 @@ function Base.iterate(c::RootedAliasChildren, s = nothing)
383392
rat = c.t
384393
@unpack iag, root = rat
385394
@unpack ag, invag, var_to_diff, visited = iag
386-
(root = var_to_diff[root]) === nothing && return nothing
395+
(iszero(root) || (root = var_to_diff[root]) === nothing) && return nothing
387396
if s === nothing
388397
stage = 1
389398
it = iterate(neighbors(invag, root))
@@ -404,7 +413,7 @@ function Base.iterate(c::RootedAliasChildren, s = nothing)
404413
it === nothing && return nothing
405414
e, ns = it
406415
# c * a = b <=> a = c * b when -1 <= c <= 1
407-
return (ag[e], RootedAliasTree(iag, e)), (stage, iterate(invag, ns))
416+
return (ag[e][1], RootedAliasTree(iag, e)), (stage, iterate(it, ns))
408417
end
409418

410419
count_nonzeros(a::AbstractArray) = count(!iszero, a)
@@ -413,22 +422,21 @@ count_nonzeros(a::AbstractArray) = count(!iszero, a)
413422
# Here we have a guarantee that they won't, so we can make this identification
414423
count_nonzeros(a::SparseVector) = nnz(a)
415424

416-
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
425+
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL, only_linear_algebraic = false, irreducibles = ())
417426
mm = copy(mm_orig)
418427
is_linear_equations = falses(size(AsSubMatrix(mm_orig), 1))
428+
diff_to_var = invview(var_to_diff)
429+
islowest = let diff_to_var = diff_to_var
430+
v -> diff_to_var[v] === nothing
431+
end
419432
for e in mm_orig.nzrows
420-
is_linear_equations[e] = true
433+
is_linear_equations[e] = all(islowest, 𝑠neighbors(graph, e))
421434
end
422435

423-
# Variables that are highest order differentiated cannot be states of an ODE
424-
is_not_potential_state = isnothing.(var_to_diff)
425-
is_linear_variables = copy(is_not_potential_state)
426-
for i in 𝑠vertices(graph)
427-
is_linear_equations[i] && continue
428-
for j in 𝑠neighbors(graph, i)
429-
is_linear_variables[j] = false
430-
end
436+
var_to_eq = let is_linear_equations = is_linear_equations, islowest = islowest, irreducibles = irreducibles
437+
maximal_matching(graph, eq -> is_linear_equations[eq], var -> islowest(var) && !(var in irreducibles))
431438
end
439+
is_linear_variables = isa.(var_to_eq, Int)
432440
solvable_variables = findall(is_linear_variables)
433441

434442
function do_bareiss!(M, Mold = nothing)
@@ -440,11 +448,10 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
440448
r !== nothing && return r
441449
rank1 = k - 1
442450
end
443-
if rank2 === nothing
444-
r = find_masked_pivot(is_not_potential_state, M, k)
445-
r !== nothing && return r
446-
rank2 = k - 1
447-
end
451+
only_linear_algebraic && return nothing
452+
# TODO: It would be better to sort the variables by
453+
# derivative order here to enable more elimination
454+
# opportunities.
448455
return find_masked_pivot(nothing, M, k)
449456
end
450457
function find_and_record_pivot(M, k)
@@ -459,46 +466,58 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
459466
end
460467
bareiss_ops = ((M, i, j) -> nothing, myswaprows!,
461468
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
462-
rank3, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
463-
rank1 = something(rank1, rank3)
464-
rank2 = something(rank2, rank3)
465-
(rank1, rank2, rank3, pivots)
469+
rank2, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
470+
rank1 = something(rank1, rank2)
471+
(rank1, rank2, pivots)
466472
end
467473

468474
return mm, solvable_variables, do_bareiss!(mm, mm_orig)
469475
end
470476

471-
function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
472-
debug = false)
473-
# Step 1: Perform bareiss factorization on the adjacency matrix of the linear
474-
# subsystem of the system we're interested in.
475-
#
477+
# Kind of like the backward substitution, but we don't actually rely on it
478+
# being lower triangular. We eliminate a variable if there are at most 2
479+
# variables left after the substitution.
480+
function lss(mm, pivots, ag)
481+
ei -> let mm = mm, pivots = pivots, ag = ag
482+
vi = pivots[ei]
483+
locally_structure_simplify!((@view mm[ei, :]), vi, ag)
484+
end
485+
end
486+
487+
function simple_aliases!(ag, graph, var_to_diff, mm_orig, only_linear_algebraic = false, irreducibles = ())
476488
# Let `m = the number of linear equations` and `n = the number of
477489
# variables`.
478490
#
479491
# `do_bareiss` conceptually gives us this system:
480-
# rank1 | [ M₁₁ M₁₂ | M₁₃ M₁₄ ] [v₁] = [0]
481-
# rank2 | [ 0 M₂₂ | M₂₃ M₂₄ ] P [v₂] = [0]
482-
# -------------------|------------------------
483-
# rank3 | [ 0 0 | M₃₃ M₃₄ ] [v₃] = [0]
484-
# [ 0 0 | 0 0 ] [v₄] = [0]
485-
mm, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss!(graph, var_to_diff,
486-
mm_orig)
492+
# rank1 | [ M₁₁ M₁₂ | M₁₃ ] [v₁] = [0]
493+
# rank2 | [ 0 M₂₂ | M₂₃ ] P [v₂] = [0]
494+
# -------------------|-------------------
495+
# [ 0 0 | 0 ] [v₃] = [0]
496+
497+
# Where `v₁` are the purely linear algebraic variables (i.e. those that only
498+
# appear in linear algebraic equations), `v₂` are the variables that may be
499+
# potentially solved by the linear system, and `v₃` are the variables that
500+
# contribute to the equations, but are not solved by the linear system. Note
501+
# that the complete system may be larger than the linear subsystem and
502+
# include variables that do not appear here.
503+
mm, solvable_variables, (rank1, rank2, pivots) = aag_bareiss!(graph, var_to_diff,
504+
mm_orig, only_linear_algebraic, irreducibles)
487505

488506
# Step 2: Simplify the system using the Bareiss factorization
489-
ag = AliasGraph(size(mm, 2))
490-
for v in setdiff(solvable_variables, @view pivots[1:rank1])
491-
ag[v] = 0
492-
end
493-
494-
# Kind of like the backward substitution, but we don't actually rely on it
495-
# being lower triangular. We eliminate a variable if there are at most 2
496-
# variables left after the substitution.
497-
function lss!(ei::Integer)
498-
vi = pivots[ei]
499-
locally_structure_simplify!((@view mm[ei, :]), vi, ag, var_to_diff)
507+
ks = keys(ag)
508+
if !only_linear_algebraic
509+
for v in setdiff(solvable_variables, @view pivots[1:rank1])
510+
ag[v] = 0
511+
end
512+
else
513+
for v in setdiff(solvable_variables, @view pivots[1:rank1])
514+
if !(v in ks)
515+
ag[v] = 0
516+
end
517+
end
500518
end
501519

520+
lss! = lss(mm, pivots, ag)
502521
# Step 2.1: Go backwards, collecting eliminated variables and substituting
503522
# alias as we go.
504523
foreach(lss!, reverse(1:rank2))
@@ -524,6 +543,18 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
524543
reduced && while any(lss!, 1:rank2)
525544
end
526545

546+
return mm
547+
end
548+
549+
function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
550+
debug = false)
551+
# Step 1: Perform bareiss factorization on the adjacency matrix of the linear
552+
# subsystem of the system we're interested in.
553+
#
554+
nvars = ndsts(graph)
555+
ag = AliasGraph(nvars)
556+
mm = simple_aliases!(ag, graph, var_to_diff, mm_orig)
557+
527558
# Step 3: Handle differentiated variables
528559
# At this point, `var_to_diff` and `ag` form a tree structure like the
529560
# following:
@@ -547,7 +578,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
547578
# Note that since we always prefer the higher differentiated variable and
548579
# with a tie breaking strategy. The root variable (in this case `z`) is
549580
# always uniquely determined. Thus, the result is well-defined.
550-
nvars = ndsts(graph)
551581
diff_to_var = invview(var_to_diff)
552582
invag = SimpleDiGraph(nvars)
553583
for (v, (coeff, alias)) in pairs(ag)
@@ -592,8 +622,11 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
592622
current_coeff_level[] = (coeff, level + 1)
593623
end
594624
end
625+
max_lv = 0
595626
for (coeff, lv, t) in StatefulAliasBFS(RootedAliasTree(iag, r))
627+
max_lv = max(max_lv, lv)
596628
v = nodevalue(t)
629+
iszero(v) && continue
597630
processed[v] = true
598631
v == r && continue
599632
if lv < length(level_to_var)
@@ -604,8 +637,20 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
604637
current_coeff_level[] = coeff, lv
605638
extreme_var(var_to_diff, v, nothing, Val(false), callback = add_alias!)
606639
end
607-
for v in level_to_var
640+
max_lv > 0 || continue
641+
642+
set_v_zero! = let newag = newag
643+
v -> newag[v] = 0
644+
end
645+
for (i, v) in enumerate(level_to_var)
646+
_alias = get(ag, v, nothing)
608647
push!(irreducibles, v)
648+
if _alias !== nothing && iszero(_alias[1]) && i < length(level_to_var)
649+
# we have `x = 0`
650+
v = level_to_var[i + 1]
651+
extreme_var(var_to_diff, v, nothing, Val(false), callback = set_v_zero!)
652+
break
653+
end
609654
end
610655
if nlevels < (new_nlevels = length(level_to_var))
611656
for i in (nlevels + 1):new_nlevels
@@ -616,20 +661,26 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
616661
end
617662

618663
# There might be "cycles" like `D(x) = x`
619-
remove_aliases = BitSet()
620664
for v in irreducibles
665+
if v in keys(newag)
666+
newag[v] = nothing
667+
end
621668
if v in keys(ag)
622-
push!(remove_aliases, v)
669+
ag[v] = nothing
623670
end
624671
end
672+
for v in keys(ag)
673+
push!(irreducibles, v)
674+
end
625675

626-
for v in remove_aliases
627-
ag[v] = nothing
676+
for (v, (c, a)) in newag
677+
va = iszero(a) ? a : fullvars[a]
678+
@info "new alias" fullvars[v]=>(c, va)
628679
end
629680
newkeys = keys(newag)
630681
if !isempty(irreducibles)
631682
for (v, (c, a)) in ag
632-
(v in newkeys || a in newkeys || v in irreducibles) && continue
683+
(a in irreducibles || v in irreducibles) && continue
633684
if iszero(c)
634685
newag[v] = c
635686
else
@@ -640,8 +691,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
640691

641692
# We cannot use the `mm` from Bareiss because it doesn't consider
642693
# irreducibles
643-
mm_new = copy(mm_orig)
644-
mm = reduce!(mm_new, ag)
694+
mm_orig2 = reduce!(copy(mm_orig), ag)
695+
mm = mm_orig2
696+
mm = simple_aliases!(ag, graph, var_to_diff, mm_orig2, true, irreducibles)
645697
end
646698

647699
debug && for (v, (c, a)) in ag
@@ -664,7 +716,7 @@ function exactdiv(a::Integer, b)
664716
return d
665717
end
666718

667-
function locally_structure_simplify!(adj_row, pivot_var, ag, var_to_diff)
719+
function locally_structure_simplify!(adj_row, pivot_var, ag)
668720
pivot_val = adj_row[pivot_var]
669721
iszero(pivot_val) && return false
670722

0 commit comments

Comments
 (0)