@@ -57,7 +57,13 @@ function alias_elimination(sys)
57
57
58
58
newstates = []
59
59
for j in eachindex (fullvars)
60
- if ! (j in keys (ag))
60
+ if j in keys (ag)
61
+ # Put back equations for alias eliminated dervars
62
+ if isdervar (state. structure, j) &&
63
+ ! (invview (state. structure. var_to_diff)[j] in keys (ag))
64
+ push! (eqs, fullvars[j] ~ subs[fullvars[j]])
65
+ end
66
+ else
61
67
isdervar (state. structure, j) || push! (newstates, fullvars[j])
62
68
end
63
69
end
@@ -192,6 +198,7 @@ struct AliasGraphKeySet <: AbstractSet{Int}
192
198
end
193
199
Base. keys (ag:: AliasGraph ) = AliasGraphKeySet (ag)
194
200
Base. iterate (agk:: AliasGraphKeySet , state... ) = Base. iterate (agk. ag. eliminated, state... )
201
+ Base. length (agk:: AliasGraphKeySet ) = Base. length (agk. ag. eliminated)
195
202
function Base. in (i:: Int , agk:: AliasGraphKeySet )
196
203
aliasto = agk. ag. aliasto
197
204
1 <= i <= length (aliasto) && aliasto[i] != = nothing
@@ -210,9 +217,11 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
210
217
is_linear_equations[e] = true
211
218
end
212
219
213
- # Variables that are highest order differentiated cannot be states of an ODE
214
- is_not_potential_state = isnothing .(var_to_diff)
215
- is_linear_variables = copy (is_not_potential_state)
220
+ # For now, only consider variables linear that are not differentiated.
221
+ # We could potentially apply the same logic to variables whose derivative
222
+ # is also linear, but that's a TODO .
223
+ diff_to_var = invview (var_to_diff)
224
+ is_linear_variables = .& (isnothing .(var_to_diff), isnothing .(diff_to_var))
216
225
for i in 𝑠vertices (graph)
217
226
is_linear_equations[i] && continue
218
227
for j in 𝑠neighbors (graph, i)
@@ -230,11 +239,9 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
230
239
r != = nothing && return r
231
240
rank1 = k - 1
232
241
end
233
- if rank2 === nothing
234
- r = find_masked_pivot (is_not_potential_state, M, k)
235
- r != = nothing && return r
236
- rank2 = k - 1
237
- end
242
+ # TODO : It would be better to sort the variables by
243
+ # derivative order here to enable more elimination
244
+ # opportunities.
238
245
return find_masked_pivot (nothing , M, k)
239
246
end
240
247
function find_and_record_pivot (M, k)
@@ -249,10 +256,9 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
249
256
end
250
257
bareiss_ops = ((M, i, j) -> nothing , myswaprows!,
251
258
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
252
- rank3, = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
253
- rank1 = something (rank1, rank3)
254
- rank2 = something (rank2, rank3)
255
- (rank1, rank2, rank3, pivots)
259
+ rank2, = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
260
+ rank1 = something (rank1, rank2)
261
+ (rank1, rank2, pivots)
256
262
end
257
263
258
264
return mm, solvable_variables, do_bareiss! (mm, mm_orig)
@@ -266,16 +272,27 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
266
272
# variables`.
267
273
#
268
274
# `do_bareiss` conceptually gives us this system:
269
- # rank1 | [ M₁₁ M₁₂ | M₁₃ M₁₄ ] [v₁] = [0]
270
- # rank2 | [ 0 M₂₂ | M₂₃ M₂₄ ] P [v₂] = [0]
275
+ # rank1 | [ M₁₁ M₁₂ | M₁₃ ] [v₁] = [0]
276
+ # rank2 | [ 0 M₂₂ | M₂₃ ] P [v₂] = [0]
271
277
# -------------------|------------------------
272
- # rank3 | [ 0 0 | M₃₃ M₃₄ ] [v₃] = [0]
273
- # [ 0 0 | 0 0 ] [v₄] = [0]
274
- mm, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss! (graph, var_to_diff,
275
- mm_orig)
278
+ # [ 0 0 | 0 ] [v₃] = [0]
279
+ #
280
+ # Where `v₁` are the purely linear variables (i.e. those that only appear in linear equations),
281
+ # `v₂` are the variables that may be potentially solved by the linear system and v₃ are the variables
282
+ # that contribute to the equations, but are not solved by the linear system. Note
283
+ # that the complete system may be larger than the linear subsystem and include variables
284
+ # that do not appear here.
285
+ mm, solvable_variables, (rank1, rank2, pivots) = aag_bareiss! (graph, var_to_diff,
286
+ mm_orig)
276
287
277
288
# Step 2: Simplify the system using the Bareiss factorization
289
+
278
290
ag = AliasGraph (size (mm, 2 ))
291
+
292
+ # First, eliminate variables that only appear in linear equations and were removed
293
+ # completely from the coefficient matrix. These are technically singularities in
294
+ # the matrix, but assigning them to 0 is a feasible assignment and works well in
295
+ # practice.
279
296
for v in setdiff (solvable_variables, @view pivots[1 : rank1])
280
297
ag[v] = 0
281
298
end
@@ -287,11 +304,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
287
304
function lss! (ei:: Integer )
288
305
vi = pivots[ei]
289
306
may_eliminate = true
290
- for v in 𝑠neighbors (graph, mm. nzrows[ei])
291
- # the differentiated variable cannot be eliminated
292
- may_eliminate &= isnothing (diff_to_var[v]) && isnothing (var_to_diff[v])
293
- end
294
- locally_structure_simplify! ((@view mm[ei, :]), vi, ag, may_eliminate)
307
+ locally_structure_simplify! ((@view mm[ei, :]), vi, ag, var_to_diff)
295
308
end
296
309
297
310
# Step 2.1: Go backwards, collecting eliminated variables and substituting
@@ -333,7 +346,7 @@ function exactdiv(a::Integer, b)
333
346
return d
334
347
end
335
348
336
- function locally_structure_simplify! (adj_row, pivot_col, ag, may_eliminate )
349
+ function locally_structure_simplify! (adj_row, pivot_col, ag, var_to_diff )
337
350
pivot_val = adj_row[pivot_col]
338
351
iszero (pivot_val) && return false
339
352
@@ -375,21 +388,36 @@ function locally_structure_simplify!(adj_row, pivot_col, ag, may_eliminate)
375
388
end
376
389
end
377
390
378
- if may_eliminate && nirreducible <= 1
391
+ if nirreducible <= 1
379
392
# There were only one or two terms left in the equation (including the
380
393
# pivot variable). We can eliminate the pivot variable.
381
394
#
382
395
# Note that when `nirreducible <= 1`, `alias_candidate` is uniquely
383
396
# determined.
384
397
if alias_candidate != = 0
398
+ # Verify that the derivative depth of the variable is at least
399
+ # as deep as that of the alias, otherwise, we can't eliminate.
400
+ pivot_var = pivot_col
401
+ alias_var = alias_candidate[2 ]
402
+ while (pivot_var = var_to_diff[pivot_col]) != = nothing
403
+ alias_var = var_to_diff[alias_var]
404
+ alias_var === nothing && return false
405
+ end
385
406
d, r = divrem (alias_candidate[1 ], pivot_val)
386
407
if r == 0 && (d == 1 || d == - 1 )
387
408
alias_candidate = - d => alias_candidate[2 ]
388
409
else
389
410
return false
390
411
end
391
412
end
392
- ag[pivot_col] = alias_candidate
413
+ diff_alias_candidate (ac) = ac === 0 ? 0 : ac[1 ] => var_to_diff[ac[2 ]]
414
+ while true
415
+ @assert ! haskey (ag, pivot_col)
416
+ ag[pivot_col] = alias_candidate
417
+ pivot_col = var_to_diff[pivot_col]
418
+ pivot_col === nothing && break
419
+ alias_candidate = diff_alias_candidate (alias_candidate)
420
+ end
393
421
zero! (adj_row)
394
422
return true
395
423
end
0 commit comments