@@ -264,7 +264,8 @@ function find_linear_variables(graph, linear_equations, var_to_diff, irreducible
264
264
return linear_variables
265
265
end
266
266
267
- function aag_bareiss! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL{T, Ti} ) where {T, Ti}
267
+ function aag_bareiss! (structure, mm_orig:: SparseMatrixCLIL{T, Ti} ) where {T, Ti}
268
+ @unpack graph, var_to_diff = structure
268
269
mm = copy (mm_orig)
269
270
linear_equations_set = BitSet (mm_orig. nzrows)
270
271
@@ -279,6 +280,7 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL{T, Ti}) wher
279
280
v -> var_to_diff[v] === nothing === invview (var_to_diff)[v]
280
281
end
281
282
is_linear_variables = is_algebraic .(1 : length (var_to_diff))
283
+ is_highest_diff = computed_highest_diff_variables (structure)
282
284
for i in 𝑠vertices (graph)
283
285
# only consider linear algebraic equations
284
286
(i in linear_equations_set && all (is_algebraic, 𝑠neighbors (graph, i))) &&
@@ -291,25 +293,31 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL{T, Ti}) wher
291
293
292
294
local bar
293
295
try
294
- bar = do_bareiss! (mm, mm_orig, is_linear_variables)
296
+ bar = do_bareiss! (mm, mm_orig, is_linear_variables, is_highest_diff )
295
297
catch e
296
298
e isa OverflowError || rethrow (e)
297
299
mm = convert (SparseMatrixCLIL{BigInt, Ti}, mm_orig)
298
- bar = do_bareiss! (mm, mm_orig, is_linear_variables)
300
+ bar = do_bareiss! (mm, mm_orig, is_linear_variables, is_highest_diff )
299
301
end
300
302
301
303
return mm, solvable_variables, bar
302
304
end
303
305
304
- function do_bareiss! (M, Mold, is_linear_variables)
306
+ function do_bareiss! (M, Mold, is_linear_variables, is_highest_diff )
305
307
rank1r = Ref {Union{Nothing, Int}} (nothing )
308
+ rank2r = Ref {Union{Nothing, Int}} (nothing )
306
309
find_pivot = let rank1r = rank1r
307
310
(M, k) -> begin
308
311
if rank1r[] === nothing
309
312
r = find_masked_pivot (is_linear_variables, M, k)
310
313
r != = nothing && return r
311
314
rank1r[] = k - 1
312
315
end
316
+ if rank2r[] === nothing
317
+ r = find_masked_pivot (is_highest_diff, M, k)
318
+ r != = nothing && return r
319
+ rank2r[] = k - 1
320
+ end
313
321
# TODO : It would be better to sort the variables by
314
322
# derivative order here to enable more elimination
315
323
# opportunities.
@@ -334,15 +342,19 @@ function do_bareiss!(M, Mold, is_linear_variables)
334
342
bareiss_ops = ((M, i, j) -> nothing , myswaprows!,
335
343
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
336
344
337
- rank2, = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
345
+ rank3, = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
346
+ rank2 = something (rank2r[], rank3)
338
347
rank1 = something (rank1r[], rank2)
339
- (rank1, rank2, pivots)
348
+ (rank1, rank2, rank3, pivots)
340
349
end
341
350
342
- function simple_aliases! (ils, graph, solvable_graph, eq_to_diff, var_to_diff)
343
- ils, solvable_variables, (rank1, rank2, pivots) = aag_bareiss! (graph,
344
- var_to_diff,
345
- ils)
351
+ function alias_eliminate_graph! (state:: TransformationState , ils:: SparseMatrixCLIL )
352
+ @unpack structure = state
353
+ @unpack graph, solvable_graph, var_to_diff, eq_to_diff = state. structure
354
+ # Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
355
+ # subsystem of the system we're interested in.
356
+ #
357
+ ils, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss! (structure, ils)
346
358
347
359
# # Step 2: Simplify the system using the Bareiss factorization
348
360
rk1vars = BitSet (@view pivots[1 : rank1])
@@ -362,14 +374,6 @@ function simple_aliases!(ils, graph, solvable_graph, eq_to_diff, var_to_diff)
362
374
return ils
363
375
end
364
376
365
- function alias_eliminate_graph! (state:: TransformationState , ils:: SparseMatrixCLIL )
366
- @unpack graph, solvable_graph, var_to_diff, eq_to_diff = state. structure
367
- # Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
368
- # subsystem of the system we're interested in.
369
- #
370
- return simple_aliases! (ils, graph, solvable_graph, eq_to_diff, var_to_diff)
371
- end
372
-
373
377
function exactdiv (a:: Integer , b)
374
378
d, r = divrem (a, b)
375
379
@assert r == 0
0 commit comments