@@ -57,13 +57,7 @@ function alias_elimination(sys)
57
57
58
58
newstates = []
59
59
for j in eachindex (fullvars)
60
- if j in keys (ag)
61
- # Put back equations for alias eliminated dervars
62
- if isdervar (state. structure, j) &&
63
- ! (invview (state. structure. var_to_diff)[j] in keys (ag))
64
- push! (eqs, fullvars[j] ~ subs[fullvars[j]])
65
- end
66
- else
60
+ if ! (j in keys (ag))
67
61
isdervar (state. structure, j) || push! (newstates, fullvars[j])
68
62
end
69
63
end
@@ -198,7 +192,6 @@ struct AliasGraphKeySet <: AbstractSet{Int}
198
192
end
199
193
Base. keys (ag:: AliasGraph ) = AliasGraphKeySet (ag)
200
194
Base. iterate (agk:: AliasGraphKeySet , state... ) = Base. iterate (agk. ag. eliminated, state... )
201
- Base. length (agk:: AliasGraphKeySet ) = Base. length (agk. ag. eliminated)
202
195
function Base. in (i:: Int , agk:: AliasGraphKeySet )
203
196
aliasto = agk. ag. aliasto
204
197
1 <= i <= length (aliasto) && aliasto[i] != = nothing
@@ -217,11 +210,9 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
217
210
is_linear_equations[e] = true
218
211
end
219
212
220
- # For now, only consider variables linear that are not differentiated.
221
- # We could potentially apply the same logic to variables whose derivative
222
- # is also linear, but that's a TODO .
223
- diff_to_var = invview (var_to_diff)
224
- is_linear_variables = .& (isnothing .(var_to_diff), isnothing .(diff_to_var))
213
+ # Variables that are highest order differentiated cannot be states of an ODE
214
+ is_not_potential_state = isnothing .(var_to_diff)
215
+ is_linear_variables = copy (is_not_potential_state)
225
216
for i in 𝑠vertices (graph)
226
217
is_linear_equations[i] && continue
227
218
for j in 𝑠neighbors (graph, i)
@@ -239,9 +230,11 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
239
230
r != = nothing && return r
240
231
rank1 = k - 1
241
232
end
242
- # TODO : It would be better to sort the variables by
243
- # derivative order here to enable more elimination
244
- # opportunities.
233
+ if rank2 === nothing
234
+ r = find_masked_pivot (is_not_potential_state, M, k)
235
+ r != = nothing && return r
236
+ rank2 = k - 1
237
+ end
245
238
return find_masked_pivot (nothing , M, k)
246
239
end
247
240
function find_and_record_pivot (M, k)
@@ -256,9 +249,10 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
256
249
end
257
250
bareiss_ops = ((M, i, j) -> nothing , myswaprows!,
258
251
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
259
- rank2, = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
260
- rank1 = something (rank1, rank2)
261
- (rank1, rank2, pivots)
252
+ rank3, = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
253
+ rank1 = something (rank1, rank3)
254
+ rank2 = something (rank2, rank3)
255
+ (rank1, rank2, rank3, pivots)
262
256
end
263
257
264
258
return mm, solvable_variables, do_bareiss! (mm, mm_orig)
@@ -272,27 +266,16 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
272
266
# variables`.
273
267
#
274
268
# `do_bareiss` conceptually gives us this system:
275
- # rank1 | [ M₁₁ M₁₂ | M₁₃ ] [v₁] = [0]
276
- # rank2 | [ 0 M₂₂ | M₂₃ ] P [v₂] = [0]
269
+ # rank1 | [ M₁₁ M₁₂ | M₁₃ M₁₄ ] [v₁] = [0]
270
+ # rank2 | [ 0 M₂₂ | M₂₃ M₂₄ ] P [v₂] = [0]
277
271
# -------------------|------------------------
278
- # [ 0 0 | 0 ] [v₃] = [0]
279
- #
280
- # Where `v₁` are the purely linear variables (i.e. those that only appear in linear equations),
281
- # `v₂` are the variables that may be potentially solved by the linear system and v₃ are the variables
282
- # that contribute to the equations, but are not solved by the linear system. Note
283
- # that the complete system may be larger than the linear subsystem and include variables
284
- # that do not appear here.
285
- mm, solvable_variables, (rank1, rank2, pivots) = aag_bareiss! (graph, var_to_diff,
286
- mm_orig)
272
+ # rank3 | [ 0 0 | M₃₃ M₃₄ ] [v₃] = [0]
273
+ # [ 0 0 | 0 0 ] [v₄] = [0]
274
+ mm, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss! (graph, var_to_diff,
275
+ mm_orig)
287
276
288
277
# Step 2: Simplify the system using the Bareiss factorization
289
-
290
278
ag = AliasGraph (size (mm, 2 ))
291
-
292
- # First, eliminate variables that only appear in linear equations and were removed
293
- # completely from the coefficient matrix. These are technically singularities in
294
- # the matrix, but assigning them to 0 is a feasible assignment and works well in
295
- # practice.
296
279
for v in setdiff (solvable_variables, @view pivots[1 : rank1])
297
280
ag[v] = 0
298
281
end
@@ -304,7 +287,11 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
304
287
function lss! (ei:: Integer )
305
288
vi = pivots[ei]
306
289
may_eliminate = true
307
- locally_structure_simplify! ((@view mm[ei, :]), vi, ag, var_to_diff)
290
+ for v in 𝑠neighbors (graph, mm. nzrows[ei])
291
+ # the differentiated variable cannot be eliminated
292
+ may_eliminate &= isnothing (diff_to_var[v]) && isnothing (var_to_diff[v])
293
+ end
294
+ locally_structure_simplify! ((@view mm[ei, :]), vi, ag, may_eliminate)
308
295
end
309
296
310
297
# Step 2.1: Go backwards, collecting eliminated variables and substituting
@@ -346,7 +333,7 @@ function exactdiv(a::Integer, b)
346
333
return d
347
334
end
348
335
349
- function locally_structure_simplify! (adj_row, pivot_col, ag, var_to_diff )
336
+ function locally_structure_simplify! (adj_row, pivot_col, ag, may_eliminate )
350
337
pivot_val = adj_row[pivot_col]
351
338
iszero (pivot_val) && return false
352
339
@@ -388,36 +375,21 @@ function locally_structure_simplify!(adj_row, pivot_col, ag, var_to_diff)
388
375
end
389
376
end
390
377
391
- if nirreducible <= 1
378
+ if may_eliminate && nirreducible <= 1
392
379
# There were only one or two terms left in the equation (including the
393
380
# pivot variable). We can eliminate the pivot variable.
394
381
#
395
382
# Note that when `nirreducible <= 1`, `alias_candidate` is uniquely
396
383
# determined.
397
384
if alias_candidate != = 0
398
- # Verify that the derivative depth of the variable is at least
399
- # as deep as that of the alias, otherwise, we can't eliminate.
400
- pivot_var = pivot_col
401
- alias_var = alias_candidate[2 ]
402
- while (pivot_var = var_to_diff[pivot_col]) != = nothing
403
- alias_var = var_to_diff[alias_var]
404
- alias_var === nothing && return false
405
- end
406
385
d, r = divrem (alias_candidate[1 ], pivot_val)
407
386
if r == 0 && (d == 1 || d == - 1 )
408
387
alias_candidate = - d => alias_candidate[2 ]
409
388
else
410
389
return false
411
390
end
412
391
end
413
- diff_alias_candidate (ac) = ac === 0 ? 0 : ac[1 ] => var_to_diff[ac[2 ]]
414
- while true
415
- @assert ! haskey (ag, pivot_col)
416
- ag[pivot_col] = alias_candidate
417
- pivot_col = var_to_diff[pivot_col]
418
- pivot_col === nothing && break
419
- alias_candidate = diff_alias_candidate (alias_candidate)
420
- end
392
+ ag[pivot_col] = alias_candidate
421
393
zero! (adj_row)
422
394
return true
423
395
end
0 commit comments