@@ -50,6 +50,11 @@ function alias_elimination(sys; debug = false)
50
50
end
51
51
end
52
52
53
+ subs = Dict ()
54
+ for (v, (coeff, alias)) in pairs (ag)
55
+ subs[fullvars[v]] = iszero (coeff) ? 0 : coeff * fullvars[alias]
56
+ end
57
+
53
58
dels = Int[]
54
59
eqs = collect (equations (state))
55
60
for (ei, e) in enumerate (mm. nzrows)
@@ -67,10 +72,6 @@ function alias_elimination(sys; debug = false)
67
72
end
68
73
deleteat! (eqs, sort! (dels))
69
74
70
- subs = Dict ()
71
- for (v, (coeff, alias)) in pairs (ag)
72
- subs[fullvars[v]] = iszero (coeff) ? 0 : coeff * fullvars[alias]
73
- end
74
75
for (ieq, eq) in enumerate (eqs)
75
76
eqs[ieq] = substitute (eq, subs)
76
77
end
@@ -229,22 +230,29 @@ end
229
230
function reduce! (mm:: SparseMatrixCLIL , ag:: AliasGraph )
230
231
dels = Int[]
231
232
for (i, rs) in enumerate (mm. row_cols)
233
+ p = i == 7
232
234
rvals = mm. row_vals[i]
233
- for (j, c) in enumerate (rs)
235
+ j = 1
236
+ while j <= length (rs)
237
+ c = rs[j]
234
238
_alias = get (ag, c, nothing )
235
239
if _alias != = nothing
236
240
push! (dels, j)
237
241
coeff, alias = _alias
238
- iszero (coeff) && continue
242
+ iszero (coeff) && (j += 1 ; continue )
239
243
inc = coeff * rvals[j]
240
244
i = searchsortedfirst (rs, alias)
241
- if i > length (rvals)
242
- push! (rs, alias)
243
- push! (rvals, inc)
245
+ if i > length (rs) || rs[i] != alias
246
+ if i <= j
247
+ j += 1
248
+ end
249
+ insert! (rs, i, alias)
250
+ insert! (rvals, i, inc)
244
251
else
245
252
rvals[i] += inc
246
253
end
247
254
end
255
+ j += 1
248
256
end
249
257
deleteat! (rs, dels)
250
258
deleteat! (rvals, dels)
@@ -304,8 +312,9 @@ function Base.iterate(it::IAGNeighbors, state = nothing)
304
312
end
305
313
else
306
314
used_ag = true
307
- if (_n = get (ag, v, nothing )) != = nothing
308
- n = _n[2 ]
315
+ # We don't care about the alising value because we only use this to
316
+ # find the root of the tree.
317
+ if (_n = get (ag, v, nothing )) != = nothing && (n = _n[2 ]) > 0
309
318
if ! visited[n]
310
319
n, lv = extreme_var (var_to_diff, n, level)
311
320
extreme_var (var_to_diff, n, nothing , Val (false ), callback = callback!)
@@ -366,7 +375,7 @@ function Base.iterate(it::StatefulAliasBFS, queue = (eltype(it)[(1, 0, it.t)]))
366
375
coeff, lv, t = popfirst! (queue)
367
376
nextlv = lv + 1
368
377
for (coeff′, c) in children (t)
369
- # TODO : maybe fix the children iterator instead.
378
+ # FIXME : use the visited cache!
370
379
# A "cycle" might occur when we have `x ~ D(x)`.
371
380
nodevalue (c) == t. root && continue
372
381
# -1 <= coeff <= 1
@@ -383,7 +392,7 @@ function Base.iterate(c::RootedAliasChildren, s = nothing)
383
392
rat = c. t
384
393
@unpack iag, root = rat
385
394
@unpack ag, invag, var_to_diff, visited = iag
386
- (root = var_to_diff[root]) === nothing && return nothing
395
+ (iszero ( root) || (root = var_to_diff[root]) === nothing ) && return nothing
387
396
if s === nothing
388
397
stage = 1
389
398
it = iterate (neighbors (invag, root))
@@ -404,7 +413,7 @@ function Base.iterate(c::RootedAliasChildren, s = nothing)
404
413
it === nothing && return nothing
405
414
e, ns = it
406
415
# c * a = b <=> a = c * b when -1 <= c <= 1
407
- return (ag[e], RootedAliasTree (iag, e)), (stage, iterate (invag , ns))
416
+ return (ag[e][ 1 ] , RootedAliasTree (iag, e)), (stage, iterate (it , ns))
408
417
end
409
418
410
419
count_nonzeros (a:: AbstractArray ) = count (! iszero, a)
@@ -413,22 +422,21 @@ count_nonzeros(a::AbstractArray) = count(!iszero, a)
413
422
# Here we have a guarantee that they won't, so we can make this identification
414
423
count_nonzeros (a:: SparseVector ) = nnz (a)
415
424
416
- function aag_bareiss! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL )
425
+ function aag_bareiss! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL , only_linear_algebraic = false , irreducibles = () )
417
426
mm = copy (mm_orig)
418
427
is_linear_equations = falses (size (AsSubMatrix (mm_orig), 1 ))
428
+ diff_to_var = invview (var_to_diff)
429
+ islowest = let diff_to_var = diff_to_var
430
+ v -> diff_to_var[v] === nothing
431
+ end
419
432
for e in mm_orig. nzrows
420
- is_linear_equations[e] = true
433
+ is_linear_equations[e] = all (islowest, 𝑠neighbors (graph, e))
421
434
end
422
435
423
- # Variables that are highest order differentiated cannot be states of an ODE
424
- is_not_potential_state = isnothing .(var_to_diff)
425
- is_linear_variables = copy (is_not_potential_state)
426
- for i in 𝑠vertices (graph)
427
- is_linear_equations[i] && continue
428
- for j in 𝑠neighbors (graph, i)
429
- is_linear_variables[j] = false
430
- end
436
+ var_to_eq = let is_linear_equations = is_linear_equations, islowest = islowest, irreducibles = irreducibles
437
+ maximal_matching (graph, eq -> is_linear_equations[eq], var -> islowest (var) && ! (var in irreducibles))
431
438
end
439
+ is_linear_variables = isa .(var_to_eq, Int)
432
440
solvable_variables = findall (is_linear_variables)
433
441
434
442
function do_bareiss! (M, Mold = nothing )
@@ -440,11 +448,10 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
440
448
r != = nothing && return r
441
449
rank1 = k - 1
442
450
end
443
- if rank2 === nothing
444
- r = find_masked_pivot (is_not_potential_state, M, k)
445
- r != = nothing && return r
446
- rank2 = k - 1
447
- end
451
+ only_linear_algebraic && return nothing
452
+ # TODO : It would be better to sort the variables by
453
+ # derivative order here to enable more elimination
454
+ # opportunities.
448
455
return find_masked_pivot (nothing , M, k)
449
456
end
450
457
function find_and_record_pivot (M, k)
@@ -459,46 +466,58 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
459
466
end
460
467
bareiss_ops = ((M, i, j) -> nothing , myswaprows!,
461
468
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
462
- rank3, = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
463
- rank1 = something (rank1, rank3)
464
- rank2 = something (rank2, rank3)
465
- (rank1, rank2, rank3, pivots)
469
+ rank2, = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
470
+ rank1 = something (rank1, rank2)
471
+ (rank1, rank2, pivots)
466
472
end
467
473
468
474
return mm, solvable_variables, do_bareiss! (mm, mm_orig)
469
475
end
470
476
471
- function alias_eliminate_graph! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL ;
472
- debug = false )
473
- # Step 1: Perform bareiss factorization on the adjacency matrix of the linear
474
- # subsystem of the system we're interested in.
475
- #
477
+ # Kind of like the backward substitution, but we don't actually rely on it
478
+ # being lower triangular. We eliminate a variable if there are at most 2
479
+ # variables left after the substitution.
480
+ function lss (mm, pivots, ag)
481
+ ei -> let mm = mm, pivots = pivots, ag = ag
482
+ vi = pivots[ei]
483
+ locally_structure_simplify! ((@view mm[ei, :]), vi, ag)
484
+ end
485
+ end
486
+
487
+ function simple_aliases! (ag, graph, var_to_diff, mm_orig, only_linear_algebraic = false , irreducibles = ())
476
488
# Let `m = the number of linear equations` and `n = the number of
477
489
# variables`.
478
490
#
479
491
# `do_bareiss` conceptually gives us this system:
480
- # rank1 | [ M₁₁ M₁₂ | M₁₃ M₁₄ ] [v₁] = [0]
481
- # rank2 | [ 0 M₂₂ | M₂₃ M₂₄ ] P [v₂] = [0]
482
- # -------------------|------------------------
483
- # rank3 | [ 0 0 | M₃₃ M₃₄ ] [v₃] = [0]
484
- # [ 0 0 | 0 0 ] [v₄] = [0]
485
- mm, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss! (graph, var_to_diff,
486
- mm_orig)
492
+ # rank1 | [ M₁₁ M₁₂ | M₁₃ ] [v₁] = [0]
493
+ # rank2 | [ 0 M₂₂ | M₂₃ ] P [v₂] = [0]
494
+ # -------------------|-------------------
495
+ # [ 0 0 | 0 ] [v₃] = [0]
496
+
497
+ # Where `v₁` are the purely linear algebraic variables (i.e. those that only
498
+ # appear in linear algebraic equations), `v₂` are the variables that may be
499
+ # potentially solved by the linear system, and `v₃` are the variables that
500
+ # contribute to the equations, but are not solved by the linear system. Note
501
+ # that the complete system may be larger than the linear subsystem and
502
+ # include variables that do not appear here.
503
+ mm, solvable_variables, (rank1, rank2, pivots) = aag_bareiss! (graph, var_to_diff,
504
+ mm_orig, only_linear_algebraic, irreducibles)
487
505
488
506
# Step 2: Simplify the system using the Bareiss factorization
489
- ag = AliasGraph ( size (mm, 2 ) )
490
- for v in setdiff (solvable_variables, @view pivots[ 1 : rank1])
491
- ag[v] = 0
492
- end
493
-
494
- # Kind of like the backward substitution, but we don't actually rely on it
495
- # being lower triangular. We eliminate a variable if there are at most 2
496
- # variables left after the substitution.
497
- function lss! (ei :: Integer )
498
- vi = pivots[ei]
499
- locally_structure_simplify! (( @view mm[ei, :]), vi, ag, var_to_diff)
507
+ ks = keys (ag )
508
+ if ! only_linear_algebraic
509
+ for v in setdiff (solvable_variables, @view pivots[ 1 : rank1])
510
+ ag[v] = 0
511
+ end
512
+ else
513
+ for v in setdiff (solvable_variables, @view pivots[ 1 : rank1])
514
+ if ! (v in ks)
515
+ ag[v] = 0
516
+ end
517
+ end
500
518
end
501
519
520
+ lss! = lss (mm, pivots, ag)
502
521
# Step 2.1: Go backwards, collecting eliminated variables and substituting
503
522
# alias as we go.
504
523
foreach (lss!, reverse (1 : rank2))
@@ -524,6 +543,18 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
524
543
reduced && while any (lss!, 1 : rank2)
525
544
end
526
545
546
+ return mm
547
+ end
548
+
549
+ function alias_eliminate_graph! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL ;
550
+ debug = false )
551
+ # Step 1: Perform bareiss factorization on the adjacency matrix of the linear
552
+ # subsystem of the system we're interested in.
553
+ #
554
+ nvars = ndsts (graph)
555
+ ag = AliasGraph (nvars)
556
+ mm = simple_aliases! (ag, graph, var_to_diff, mm_orig)
557
+
527
558
# Step 3: Handle differentiated variables
528
559
# At this point, `var_to_diff` and `ag` form a tree structure like the
529
560
# following:
@@ -547,7 +578,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
547
578
# Note that since we always prefer the higher differentiated variable and
548
579
# with a tie breaking strategy. The root variable (in this case `z`) is
549
580
# always uniquely determined. Thus, the result is well-defined.
550
- nvars = ndsts (graph)
551
581
diff_to_var = invview (var_to_diff)
552
582
invag = SimpleDiGraph (nvars)
553
583
for (v, (coeff, alias)) in pairs (ag)
@@ -592,8 +622,11 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
592
622
current_coeff_level[] = (coeff, level + 1 )
593
623
end
594
624
end
625
+ max_lv = 0
595
626
for (coeff, lv, t) in StatefulAliasBFS (RootedAliasTree (iag, r))
627
+ max_lv = max (max_lv, lv)
596
628
v = nodevalue (t)
629
+ iszero (v) && continue
597
630
processed[v] = true
598
631
v == r && continue
599
632
if lv < length (level_to_var)
@@ -604,8 +637,20 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
604
637
current_coeff_level[] = coeff, lv
605
638
extreme_var (var_to_diff, v, nothing , Val (false ), callback = add_alias!)
606
639
end
607
- for v in level_to_var
640
+ max_lv > 0 || continue
641
+
642
+ set_v_zero! = let newag = newag
643
+ v -> newag[v] = 0
644
+ end
645
+ for (i, v) in enumerate (level_to_var)
646
+ _alias = get (ag, v, nothing )
608
647
push! (irreducibles, v)
648
+ if _alias != = nothing && iszero (_alias[1 ]) && i < length (level_to_var)
649
+ # we have `x = 0`
650
+ v = level_to_var[i + 1 ]
651
+ extreme_var (var_to_diff, v, nothing , Val (false ), callback = set_v_zero!)
652
+ break
653
+ end
609
654
end
610
655
if nlevels < (new_nlevels = length (level_to_var))
611
656
for i in (nlevels + 1 ): new_nlevels
@@ -616,20 +661,26 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
616
661
end
617
662
618
663
# There might be "cycles" like `D(x) = x`
619
- remove_aliases = BitSet ()
620
664
for v in irreducibles
665
+ if v in keys (newag)
666
+ newag[v] = nothing
667
+ end
621
668
if v in keys (ag)
622
- push! (remove_aliases, v)
669
+ ag[v] = nothing
623
670
end
624
671
end
672
+ for v in keys (ag)
673
+ push! (irreducibles, v)
674
+ end
625
675
626
- for v in remove_aliases
627
- ag[v] = nothing
676
+ for (v, (c, a)) in newag
677
+ va = iszero (a) ? a : fullvars[a]
678
+ @info " new alias" fullvars[v]=> (c, va)
628
679
end
629
680
newkeys = keys (newag)
630
681
if ! isempty (irreducibles)
631
682
for (v, (c, a)) in ag
632
- (v in newkeys || a in newkeys || v in irreducibles) && continue
683
+ (a in irreducibles || v in irreducibles) && continue
633
684
if iszero (c)
634
685
newag[v] = c
635
686
else
@@ -640,8 +691,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
640
691
641
692
# We cannot use the `mm` from Bareiss because it doesn't consider
642
693
# irreducibles
643
- mm_new = copy (mm_orig)
644
- mm = reduce! (mm_new, ag)
694
+ mm_orig2 = reduce! (copy (mm_orig), ag)
695
+ mm = mm_orig2
696
+ mm = simple_aliases! (ag, graph, var_to_diff, mm_orig2, true , irreducibles)
645
697
end
646
698
647
699
debug && for (v, (c, a)) in ag
@@ -664,7 +716,7 @@ function exactdiv(a::Integer, b)
664
716
return d
665
717
end
666
718
667
- function locally_structure_simplify! (adj_row, pivot_var, ag, var_to_diff )
719
+ function locally_structure_simplify! (adj_row, pivot_var, ag)
668
720
pivot_val = adj_row[pivot_var]
669
721
iszero (pivot_val) && return false
670
722
0 commit comments