@@ -426,98 +426,82 @@ function find_linear_variables(graph, linear_equations, var_to_diff, irreducible
426
426
return linear_variables
427
427
end
428
428
429
- function aag_bareiss! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL , irreducibles = () )
429
+ function aag_bareiss! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL )
430
430
mm = copy (mm_orig)
431
- linear_equations = mm_orig. nzrows
432
- is_linear_equations = falses (size (AsSubMatrix (mm_orig), 1 ))
433
- for e in mm_orig. nzrows
434
- is_linear_equations[e] = true
435
- end
431
+ linear_equations_set = BitSet (mm_orig. nzrows)
436
432
437
- # If linear highest differentiated variables cannot be assigned to a pivot,
438
- # then we can set it to zero. We use `rank1` to track this.
439
- #
440
- # We only use alias graph to record reducible variables. We use `rank2` to
441
- # track this.
433
+ # All unassigned (not a pivot) algebraic variables that only appears in
434
+ # linear algebraic equations can be set to 0.
442
435
#
443
436
# For all the other variables, we can update the original system with
444
437
# Bareiss'ed coefficients as Gaussian elimination is nullspace perserving
445
438
# and we are only working on linear homogeneous subsystem.
446
- is_reducible = trues (length (var_to_diff))
447
- # TODO : what's the correct criterion here?
448
- is_linear_variables = isnothing .(var_to_diff) .& isnothing .(invview (var_to_diff))
439
+
440
+ is_algebraic = let var_to_diff = var_to_diff
441
+ v -> var_to_diff[v] === nothing === invview (var_to_diff)[v]
442
+ end
443
+ is_linear_variables = is_algebraic .(1 : length (var_to_diff))
449
444
for i in 𝑠vertices (graph)
450
- is_linear_equations[i] && continue
445
+ # only consider linear algebraic equations
446
+ (i in linear_equations_set && all (is_algebraic, 𝑠neighbors (graph, i))) && continue
451
447
for j in 𝑠neighbors (graph, i)
452
448
is_linear_variables[j] = false
453
449
end
454
450
end
455
451
solvable_variables = findall (is_linear_variables)
456
452
457
- function do_bareiss! (M, Mold = nothing )
458
- rank1 = rank2 = nothing
459
- pivots = Int[]
460
- function find_pivot (M, k)
461
- if rank1 === nothing
453
+ return mm, solvable_variables, do_bareiss! (mm, mm_orig, is_linear_variables)
454
+ end
455
+
456
+ function do_bareiss! (M, Mold, is_linear_variables)
457
+ rank1r = Ref {Union{Nothing, Int}} (nothing )
458
+ find_pivot = let rank1r = rank1r
459
+ (M, k) -> begin
460
+ if rank1r[] === nothing
462
461
r = find_masked_pivot (is_linear_variables, M, k)
463
462
r != = nothing && return r
464
- rank1 = k - 1
465
- end
466
- if rank2 === nothing
467
- r = find_masked_pivot (is_reducible, M, k)
468
- r != = nothing && return r
469
- rank2 = k - 1
463
+ rank1r[] = k - 1
470
464
end
471
465
# TODO : It would be better to sort the variables by
472
466
# derivative order here to enable more elimination
473
467
# opportunities.
474
468
return find_masked_pivot (nothing , M, k)
475
469
end
476
- function find_and_record_pivot (M, k)
470
+ end
471
+ pivots = Int[]
472
+ find_and_record_pivot = let pivots = pivots
473
+ (M, k) -> begin
477
474
r = find_pivot (M, k)
478
475
r === nothing && return nothing
479
476
push! (pivots, r[1 ][2 ])
480
477
return r
481
478
end
482
- function myswaprows! (M, i, j)
479
+ end
480
+ myswaprows! = let Mold = Mold
481
+ (M, i, j) -> begin
483
482
Mold != = nothing && swaprows! (Mold, i, j)
484
483
swaprows! (M, i, j)
485
484
end
486
- bareiss_ops = ((M, i, j) -> nothing , myswaprows!,
487
- bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
488
- rank3, = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
489
- rank2 = something (rank2, rank3)
490
- rank1 = something (rank1, rank2)
491
- (rank1, rank2, rank3, pivots)
492
485
end
493
-
494
- return mm, solvable_variables, do_bareiss! (mm, mm_orig)
486
+ bareiss_ops = ((M, i, j) -> nothing , myswaprows!,
487
+ bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
488
+ rank2, = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
489
+ rank1 = something (rank1r[], rank2)
490
+ (rank1, rank2, pivots)
495
491
end
496
492
497
493
# Kind of like the backward substitution, but we don't actually rely on it
498
494
# being lower triangular. We eliminate a variable if there are at most 2
499
495
# variables left after the substitution.
500
- function lss (mm, pivots, ag )
496
+ function lss (mm, ag, pivots )
501
497
ei -> let mm = mm, pivots = pivots, ag = ag
502
- vi = pivots[ei]
498
+ vi = pivots === nothing ? nothing : pivots [ei]
503
499
locally_structure_simplify! ((@view mm[ei, :]), vi, ag)
504
500
end
505
501
end
506
502
507
- function simple_aliases! (ag, graph, var_to_diff, mm_orig, irreducibles = ())
508
- mm, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss! (graph, var_to_diff,
509
- mm_orig,
510
- irreducibles)
511
-
512
- # Step 2: Simplify the system using the Bareiss factorization
513
- rk1vars = BitSet (@view pivots[1 : rank1])
514
- for v in solvable_variables
515
- v in rk1vars && continue
516
- ag[v] = 0
517
- end
518
-
519
- echelon_mm = copy (mm)
520
- lss! = lss (mm, pivots, ag)
503
+ function reduce! (mm, mm_orig, ag, rank2, pivots = nothing )
504
+ lss! = lss (mm, ag, pivots)
521
505
# Step 2.1: Go backwards, collecting eliminated variables and substituting
522
506
# alias as we go.
523
507
foreach (lss!, reverse (1 : rank2))
@@ -543,6 +527,20 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, irreducibles = ())
543
527
reduced && while any (lss!, 1 : rank2)
544
528
end
545
529
530
+ return mm
531
+ end
532
+
533
+ function simple_aliases! (ag, graph, var_to_diff, mm_orig)
534
+ echelon_mm, solvable_variables, (rank1, rank2, pivots) = aag_bareiss! (graph, var_to_diff, mm_orig)
535
+
536
+ # Step 2: Simplify the system using the Bareiss factorization
537
+ rk1vars = BitSet (@view pivots[1 : rank1])
538
+ for v in solvable_variables
539
+ v in rk1vars && continue
540
+ ag[v] = 0
541
+ end
542
+
543
+ mm = reduce! (copy (echelon_mm), mm_orig, ag, rank2, pivots)
546
544
return mm, echelon_mm
547
545
end
548
546
@@ -587,13 +585,13 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
587
585
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
588
586
end
589
587
diff_aliases = Vector{Pair{Int, Int}}[]
588
+ stem = Int[]
589
+ stem_set = BitSet ()
590
590
for (v, dv) in enumerate (var_to_diff)
591
591
processed[v] && continue
592
592
(dv === nothing && diff_to_var[v] === nothing ) && continue
593
593
r = find_root! (dls, g, v)
594
594
prev_r = - 1
595
- stem = Int[]
596
- stem_set = BitSet ()
597
595
for _ in 1 : 10_000 # just to make sure that we don't stuck in an infinite loop
598
596
reach₌ = Pair{Int, Int}[]
599
597
r === nothing || for n in neighbors (eqg, r)
@@ -650,6 +648,21 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
650
648
# edges.
651
649
weighted_transitiveclosure! (eqg)
652
650
# Canonicalize by preferring the lower differentiated variable
651
+ # If we have the system
652
+ # ```
653
+ # D(x) ~ x
654
+ # D(x) + D(y) ~ 0
655
+ # ```
656
+ # preferring the lower variable would lead to
657
+ # ```
658
+ # D(x) ~ x <== added back because `x := D(x)` removes `D(x)`
659
+ # D(y) ~ -x
660
+ # ```
661
+ # while preferring the higher variable would lead to
662
+ # ```
663
+ # D(x) + D(y) ~ 0
664
+ # ```
665
+ # which is not correct.
653
666
for i in 1 : (length (stem) - 1 )
654
667
r = stem[i]
655
668
for dr in @view stem[(i + 1 ): end ]
@@ -712,6 +725,8 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
712
725
end
713
726
empty! (dls. visited)
714
727
empty! (diff_aliases)
728
+ empty! (stem)
729
+ empty! (stem_set)
715
730
end
716
731
# update `dag`
717
732
for k in keys (dag)
@@ -720,7 +735,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
720
735
721
736
# Step 4: Merge dag and ag
722
737
removed_aliases = BitSet ()
723
- freshag = AliasGraph (nvars)
738
+ merged_ag = AliasGraph (nvars)
724
739
for (v, (c, a)) in dag
725
740
# D(x) ~ D(y) cannot be removed if x and y are not aliases
726
741
if v != a && ! iszero (a)
@@ -732,23 +747,21 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
732
747
vv === nothing && break
733
748
if ! (haskey (dag, vv) && dag[vv][2 ] == diff_to_var[aa])
734
749
push! (removed_aliases, vv′)
735
- @goto SKIP_FRESHAG
750
+ @goto SKIP_merged_ag
736
751
end
737
752
end
738
753
end
739
- freshag [v] = c => a
740
- @label SKIP_FRESHAG
754
+ merged_ag [v] = c => a
755
+ @label SKIP_merged_ag
741
756
push! (removed_aliases, a)
742
757
end
743
758
for (v, (c, a)) in ag
744
759
(processed[v] || (! iszero (a) && processed[a])) && continue
745
760
v in removed_aliases && continue
746
- freshag[v] = c => a
747
- end
748
- if freshag != ag
749
- ag = freshag
750
- mm = reduce! (copy (echelon_mm), ag)
761
+ merged_ag[v] = c => a
751
762
end
763
+ ag = merged_ag
764
+ mm = reduce! (copy (echelon_mm), mm_orig, ag, size (echelon_mm, 1 ))
752
765
753
766
# Step 5: Reflect our update decisions back into the graph, and make sure
754
767
# that the RHS of observable variables are defined.
@@ -776,7 +789,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
776
789
ag = finalag
777
790
778
791
if needs_update
779
- mm = reduce! (copy (echelon_mm), ag )
792
+ mm = reduce! (copy (echelon_mm), mm_orig, ag, size (echelon_mm, 1 ) )
780
793
for (ei, e) in enumerate (mm. nzrows)
781
794
set_neighbors! (graph, e, mm. row_cols[ei])
782
795
end
@@ -803,6 +816,7 @@ function exactdiv(a::Integer, b)
803
816
end
804
817
805
818
function locally_structure_simplify! (adj_row, pivot_var, ag)
819
+ # If `pivot_var === nothing`, then we only apply `ag` to `adj_row`
806
820
if pivot_var === nothing
807
821
pivot_val = nothing
808
822
else
@@ -857,7 +871,7 @@ function locally_structure_simplify!(adj_row, pivot_var, ag)
857
871
else
858
872
dropzeros! (adj_row)
859
873
end
860
- return true
874
+ return false
861
875
end
862
876
# If there were only one or two terms left in the equation (including the
863
877
# pivot variable). We can eliminate the pivot variable. Note that when
0 commit comments