Skip to content

Commit 8978fc0

Browse files
YingboMaKeno
andcommitted
Add back rank2 (highest diffed variables)
Co-authored-by: Keno Fischer <[email protected]>
1 parent 4b66189 commit 8978fc0

File tree

3 files changed

+24
-19
lines changed

3 files changed

+24
-19
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ export tearing_assignments, tearing_substitution
5454
export torn_system_jacobian_sparsity
5555
export full_equations
5656
export but_ordered_incidence, lowest_order_variable_mask, highest_order_variable_mask
57+
export computed_highest_diff_variables
5758

5859
include("utils.jl")
5960
include("pantelides.jl")

src/structural_transformation/pantelides.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
7070
end
7171

7272
"""
73-
computed_highest_diff_variables(var_to_diff)
73+
computed_highest_diff_variables(structure)
7474
7575
Computes which variables are the "highest-differentiated" for purposes of
7676
pantelides. Ordinarily this is relatively straightforward. However, in our

src/systems/alias_elimination.jl

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,8 @@ function find_linear_variables(graph, linear_equations, var_to_diff, irreducible
264264
return linear_variables
265265
end
266266

267-
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti}
267+
function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti}
268+
@unpack graph, var_to_diff = structure
268269
mm = copy(mm_orig)
269270
linear_equations_set = BitSet(mm_orig.nzrows)
270271

@@ -279,6 +280,7 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL{T, Ti}) wher
279280
v -> var_to_diff[v] === nothing === invview(var_to_diff)[v]
280281
end
281282
is_linear_variables = is_algebraic.(1:length(var_to_diff))
283+
is_highest_diff = computed_highest_diff_variables(structure)
282284
for i in 𝑠vertices(graph)
283285
# only consider linear algebraic equations
284286
(i in linear_equations_set && all(is_algebraic, 𝑠neighbors(graph, i))) &&
@@ -291,25 +293,31 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL{T, Ti}) wher
291293

292294
local bar
293295
try
294-
bar = do_bareiss!(mm, mm_orig, is_linear_variables)
296+
bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff)
295297
catch e
296298
e isa OverflowError || rethrow(e)
297299
mm = convert(SparseMatrixCLIL{BigInt, Ti}, mm_orig)
298-
bar = do_bareiss!(mm, mm_orig, is_linear_variables)
300+
bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff)
299301
end
300302

301303
return mm, solvable_variables, bar
302304
end
303305

304-
function do_bareiss!(M, Mold, is_linear_variables)
306+
function do_bareiss!(M, Mold, is_linear_variables, is_highest_diff)
305307
rank1r = Ref{Union{Nothing, Int}}(nothing)
308+
rank2r = Ref{Union{Nothing, Int}}(nothing)
306309
find_pivot = let rank1r = rank1r
307310
(M, k) -> begin
308311
if rank1r[] === nothing
309312
r = find_masked_pivot(is_linear_variables, M, k)
310313
r !== nothing && return r
311314
rank1r[] = k - 1
312315
end
316+
if rank2r[] === nothing
317+
r = find_masked_pivot(is_highest_diff, M, k)
318+
r !== nothing && return r
319+
rank2r[] = k - 1
320+
end
313321
# TODO: It would be better to sort the variables by
314322
# derivative order here to enable more elimination
315323
# opportunities.
@@ -334,15 +342,19 @@ function do_bareiss!(M, Mold, is_linear_variables)
334342
bareiss_ops = ((M, i, j) -> nothing, myswaprows!,
335343
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
336344

337-
rank2, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
345+
rank3, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
346+
rank2 = something(rank2r[], rank3)
338347
rank1 = something(rank1r[], rank2)
339-
(rank1, rank2, pivots)
348+
(rank1, rank2, rank3, pivots)
340349
end
341350

342-
function simple_aliases!(ils, graph, solvable_graph, eq_to_diff, var_to_diff)
343-
ils, solvable_variables, (rank1, rank2, pivots) = aag_bareiss!(graph,
344-
var_to_diff,
345-
ils)
351+
function alias_eliminate_graph!(state::TransformationState, ils::SparseMatrixCLIL)
352+
@unpack structure = state
353+
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure
354+
# Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
355+
# subsystem of the system we're interested in.
356+
#
357+
ils, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss!(structure, ils)
346358

347359
## Step 2: Simplify the system using the Bareiss factorization
348360
rk1vars = BitSet(@view pivots[1:rank1])
@@ -362,14 +374,6 @@ function simple_aliases!(ils, graph, solvable_graph, eq_to_diff, var_to_diff)
362374
return ils
363375
end
364376

365-
function alias_eliminate_graph!(state::TransformationState, ils::SparseMatrixCLIL)
366-
@unpack graph, solvable_graph, var_to_diff, eq_to_diff = state.structure
367-
# Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
368-
# subsystem of the system we're interested in.
369-
#
370-
return simple_aliases!(ils, graph, solvable_graph, eq_to_diff, var_to_diff)
371-
end
372-
373377
function exactdiv(a::Integer, b)
374378
d, r = divrem(a, b)
375379
@assert r == 0

0 commit comments

Comments
 (0)