@@ -420,25 +420,23 @@ count_nonzeros(a::AbstractArray) = count(!iszero, a)
420
420
# Here we have a guarantee that they won't, so we can make this identification
421
421
count_nonzeros (a:: SparseVector ) = nnz (a)
422
422
423
- function aag_bareiss! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL ,
424
- only_linear_algebraic = false , irreducibles = () )
423
+ function aag_bareiss! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL , only_algebraic,
424
+ irreducibles)
425
425
mm = copy (mm_orig)
426
426
is_linear_equations = falses (size (AsSubMatrix (mm_orig), 1 ))
427
- diff_to_var = invview (var_to_diff)
428
- islowest = let diff_to_var = diff_to_var
429
- v -> diff_to_var[v] === nothing
430
- end
431
- for e in mm_orig. nzrows
432
- is_linear_equations[e] = all (islowest, 𝑠neighbors (graph, e))
433
- end
434
-
435
- var_to_eq = let is_linear_equations = is_linear_equations, islowest = islowest,
436
- irreducibles = irreducibles
437
427
438
- maximal_matching (graph, eq -> is_linear_equations[eq],
439
- var -> islowest (var) && ! (var in irreducibles))
428
+ is_not_potential_state = isnothing .(var_to_diff)
429
+ for v in irreducibles
430
+ is_not_potential_state[v] = false
431
+ end
432
+ is_linear_variables = only_algebraic ? copy (is_not_potential_state) :
433
+ is_not_potential_state
434
+ for i in 𝑠vertices (graph)
435
+ is_linear_equations[i] && continue
436
+ for j in 𝑠neighbors (graph, i)
437
+ is_linear_variables[j] = false
438
+ end
440
439
end
441
- is_linear_variables = isa .(var_to_eq, Int)
442
440
solvable_variables = findall (is_linear_variables)
443
441
444
442
function do_bareiss! (M, Mold = nothing )
@@ -450,7 +448,13 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL,
450
448
r != = nothing && return r
451
449
rank1 = k - 1
452
450
end
453
- only_linear_algebraic && return nothing
451
+ if only_algebraic
452
+ if rank2 === nothing
453
+ r = find_masked_pivot (is_not_potential_state, M, k)
454
+ r != = nothing && return r
455
+ rank2 = k - 1
456
+ end
457
+ end
454
458
# TODO : It would be better to sort the variables by
455
459
# derivative order here to enable more elimination
456
460
# opportunities.
@@ -468,9 +472,10 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL,
468
472
end
469
473
bareiss_ops = ((M, i, j) -> nothing , myswaprows!,
470
474
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
471
- rank2 , = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
475
+ rank3 , = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
472
476
rank1 = something (rank1, rank2)
473
- (rank1, rank2, pivots)
477
+ rank2 = something (rank2, rank3)
478
+ (rank1, rank2, rank3, pivots)
474
479
end
475
480
476
481
return mm, solvable_variables, do_bareiss! (mm, mm_orig)
@@ -486,8 +491,7 @@ function lss(mm, pivots, ag)
486
491
end
487
492
end
488
493
489
- function simple_aliases! (ag, graph, var_to_diff, mm_orig, only_linear_algebraic = false ,
490
- irreducibles = ())
494
+ function simple_aliases! (ag, graph, var_to_diff, mm_orig, only_algebraic, irreducibles = ())
491
495
# Let `m = the number of linear equations` and `n = the number of
492
496
# variables`.
493
497
#
@@ -503,23 +507,15 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, only_linear_algebraic
503
507
# contribute to the equations, but are not solved by the linear system. Note
504
508
# that the complete system may be larger than the linear subsystem and
505
509
# include variables that do not appear here.
506
- mm, solvable_variables, (rank1, rank2, pivots) = aag_bareiss! (graph, var_to_diff,
507
- mm_orig,
508
- only_linear_algebraic ,
509
- irreducibles)
510
+ mm, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss! (graph, var_to_diff,
511
+ mm_orig,
512
+ only_algebraic ,
513
+ irreducibles)
510
514
511
515
# Step 2: Simplify the system using the Bareiss factorization
512
516
ks = keys (ag)
513
- if ! only_linear_algebraic
514
- for v in setdiff (solvable_variables, @view pivots[1 : rank1])
515
- ag[v] = 0
516
- end
517
- else
518
- for v in setdiff (solvable_variables, @view pivots[1 : rank1])
519
- if ! (v in ks)
520
- ag[v] = 0
521
- end
522
- end
517
+ for v in setdiff (solvable_variables, @view pivots[1 : rank1])
518
+ ag[v] = 0
523
519
end
524
520
525
521
lss! = lss (mm, pivots, ag)
@@ -558,7 +554,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
558
554
#
559
555
nvars = ndsts (graph)
560
556
ag = AliasGraph (nvars)
561
- mm = simple_aliases! (ag, graph, var_to_diff, mm_orig)
557
+ mm = simple_aliases! (ag, graph, var_to_diff, mm_orig, false )
562
558
563
559
# Step 3: Handle differentiated variables
564
560
# At this point, `var_to_diff` and `ag` form a tree structure like the
@@ -665,39 +661,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
665
661
end
666
662
end
667
663
668
- # There might be "cycles" like `D(x) = x`
669
- for v in irreducibles
670
- if v in keys (newag)
671
- newag[v] = nothing
672
- end
673
- if v in keys (ag)
674
- ag[v] = nothing
675
- end
676
- end
677
- for v in keys (ag)
678
- push! (irreducibles, v)
679
- end
680
-
681
- for (v, (c, a)) in newag
682
- va = iszero (a) ? a : fullvars[a]
683
- @info " new alias" fullvars[v]=> (c, va)
684
- end
685
- newkeys = keys (newag)
686
664
if ! isempty (irreducibles)
687
- for (v, (c, a)) in ag
688
- (a in irreducibles || v in irreducibles) && continue
689
- if iszero (c)
690
- newag[v] = c
691
- else
692
- newag[v] = c => a
693
- end
694
- end
695
665
ag = newag
696
-
697
- # We cannot use the `mm` from Bareiss because it doesn't consider
698
- # irreducibles
699
- mm_orig2 = reduce! (copy (mm_orig), ag)
700
- mm = mm_orig2
666
+ mm_orig2 = isempty (ag) ? mm_orig : reduce! (copy (mm_orig), ag)
701
667
mm = simple_aliases! (ag, graph, var_to_diff, mm_orig2, true , irreducibles)
702
668
end
703
669
0 commit comments