Skip to content

Commit d45503c

Browse files
authored
Merge pull request #1635 from Keno/kf/morealias
Strengthen alias elimination
2 parents e104da1 + fcc6a7a commit d45503c

File tree

5 files changed

+122
-31
lines changed

5 files changed

+122
-31
lines changed

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,57 @@ end
3737

3838
function tear_graph_modia(structure::SystemStructure; varfilter = v -> true,
3939
eqfilter = eq -> true)
40+
# It would be possible here to simply iterate over all variables and attempt to
41+
# use tearEquations! to produce a matching that greedily selects the minimal
42+
# number of torn variables. However, we can do this process faster if we first
43+
# compute the strongly connected components. In the absence of cycles and
44+
# non-solvability, a maximal matching on the original graph will give us an
45+
# optimal assignment. However, even with cycles, we can use the maximal matching
46+
# to give us a good starting point for a good matching and then proceed to
47+
# reverse edges in each scc to improve the solution. Note that it is possible
48+
# to have optimal solutions that cannot be found by this process. We will not
49+
# find them here [TODO: It would be good to have an explicit example of this.]
50+
4051
@unpack graph, solvable_graph = structure
4152
var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter))
4253
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)
4354

55+
# Here, we're using a maximal matching on the post-pantelides system to find
56+
# the strongly connected components of the system (of variables that depend
57+
# on each other). The strongly connected components are unique, however, the
58+
# maximal matching itself is not. Every maximal matching gives rise to the
59+
# same set of strongly connected components, but the associated equations need
60+
# not be the same. In the absence of solvability constraints, this may be a
61+
# small issue, but here it is possible that an equation got assigned to an
62+
# scc that cannot actually use it for solving a variable, but still precludes
63+
# another scc from using it. To avoid this, we delete any assignments that
64+
# are not in the solvable graph and extend the set of considered equations
65+
# below.
66+
for var in ndsts(solvable_graph)
67+
var_eq_matching[var] === unassigned && continue
68+
if !(BipartiteEdge(var, var_eq_matching[var]) in solvable_graph)
69+
var_eq_matching[var] = unassigned
70+
end
71+
end
72+
4473
for vars in var_sccs
4574
filtered_vars = filter(varfilter, vars)
4675
ieqs = Int[var_eq_matching[v]
4776
for v in filtered_vars if var_eq_matching[v] !== unassigned]
4877
for var in vars
4978
var_eq_matching[var] = unassigned
5079
end
80+
for var in filtered_vars
81+
# Add any equations that we may not have been able to use earlier to see
82+
# if a different matching may have been possible.
83+
for eq′ in 𝑑neighbors(solvable_graph, var)
84+
eqfilter(eq′) || continue
85+
eq′ in ieqs && continue
86+
if invview(var_eq_matching)[eq′] === unassigned
87+
push!(ieqs, eq′)
88+
end
89+
end
90+
end
5191
tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, ieqs, filtered_vars)
5292
end
5393

src/systems/alias_elimination.jl

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,13 @@ function alias_elimination(sys)
5757

5858
newstates = []
5959
for j in eachindex(fullvars)
60-
if !(j in keys(ag))
60+
if j in keys(ag)
61+
# Put back equations for alias-eliminated derivative variables
62+
if isdervar(state.structure, j) &&
63+
!(invview(state.structure.var_to_diff)[j] in keys(ag))
64+
push!(eqs, fullvars[j] ~ subs[fullvars[j]])
65+
end
66+
else
6167
isdervar(state.structure, j) || push!(newstates, fullvars[j])
6268
end
6369
end
@@ -192,6 +198,7 @@ struct AliasGraphKeySet <: AbstractSet{Int}
192198
end
193199
Base.keys(ag::AliasGraph) = AliasGraphKeySet(ag)
194200
Base.iterate(agk::AliasGraphKeySet, state...) = Base.iterate(agk.ag.eliminated, state...)
201+
Base.length(agk::AliasGraphKeySet) = Base.length(agk.ag.eliminated)
195202
function Base.in(i::Int, agk::AliasGraphKeySet)
196203
aliasto = agk.ag.aliasto
197204
1 <= i <= length(aliasto) && aliasto[i] !== nothing
@@ -210,9 +217,11 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
210217
is_linear_equations[e] = true
211218
end
212219

213-
# Variables that are highest order differentiated cannot be states of an ODE
214-
is_not_potential_state = isnothing.(var_to_diff)
215-
is_linear_variables = copy(is_not_potential_state)
220+
# For now, only consider variables linear that are not differentiated.
221+
# We could potentially apply the same logic to variables whose derivative
222+
# is also linear, but that's a TODO.
223+
diff_to_var = invview(var_to_diff)
224+
is_linear_variables = .&(isnothing.(var_to_diff), isnothing.(diff_to_var))
216225
for i in 𝑠vertices(graph)
217226
is_linear_equations[i] && continue
218227
for j in 𝑠neighbors(graph, i)
@@ -230,11 +239,9 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
230239
r !== nothing && return r
231240
rank1 = k - 1
232241
end
233-
if rank2 === nothing
234-
r = find_masked_pivot(is_not_potential_state, M, k)
235-
r !== nothing && return r
236-
rank2 = k - 1
237-
end
242+
# TODO: It would be better to sort the variables by
243+
# derivative order here to enable more elimination
244+
# opportunities.
238245
return find_masked_pivot(nothing, M, k)
239246
end
240247
function find_and_record_pivot(M, k)
@@ -249,10 +256,9 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
249256
end
250257
bareiss_ops = ((M, i, j) -> nothing, myswaprows!,
251258
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
252-
rank3, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
253-
rank1 = something(rank1, rank3)
254-
rank2 = something(rank2, rank3)
255-
(rank1, rank2, rank3, pivots)
259+
rank2, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
260+
rank1 = something(rank1, rank2)
261+
(rank1, rank2, pivots)
256262
end
257263

258264
return mm, solvable_variables, do_bareiss!(mm, mm_orig)
@@ -266,16 +272,27 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
266272
# variables`.
267273
#
268274
# `do_bareiss` conceptually gives us this system:
269-
# rank1 | [ M₁₁ M₁₂ | M₁₃ M₁₄ ] [v₁] = [0]
270-
# rank2 | [ 0 M₂₂ | M₂₃ M₂₄ ] P [v₂] = [0]
275+
# rank1 | [ M₁₁ M₁₂ | M₁₃ ] [v₁] = [0]
276+
# rank2 | [ 0 M₂₂ | M₂₃ ] P [v₂] = [0]
271277
# -------------------|------------------------
272-
# rank3 | [ 0 0 | M₃₃ M₃₄ ] [v₃] = [0]
273-
# [ 0 0 | 0 0 ] [v₄] = [0]
274-
mm, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss!(graph, var_to_diff,
275-
mm_orig)
278+
# [ 0 0 | 0 ] [v₃] = [0]
279+
#
280+
# Where `v₁` are the purely linear variables (i.e. those that only appear in linear equations),
281+
# `v₂` are the variables that may be potentially solved by the linear system and v₃ are the variables
282+
# that contribute to the equations, but are not solved by the linear system. Note
283+
# that the complete system may be larger than the linear subsystem and include variables
284+
# that do not appear here.
285+
mm, solvable_variables, (rank1, rank2, pivots) = aag_bareiss!(graph, var_to_diff,
286+
mm_orig)
276287

277288
# Step 2: Simplify the system using the Bareiss factorization
289+
278290
ag = AliasGraph(size(mm, 2))
291+
292+
# First, eliminate variables that only appear in linear equations and were removed
293+
# completely from the coefficient matrix. These are technically singularities in
294+
# the matrix, but assigning them to 0 is a feasible assignment and works well in
295+
# practice.
279296
for v in setdiff(solvable_variables, @view pivots[1:rank1])
280297
ag[v] = 0
281298
end
@@ -287,11 +304,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
287304
function lss!(ei::Integer)
288305
vi = pivots[ei]
289306
may_eliminate = true
290-
for v in 𝑠neighbors(graph, mm.nzrows[ei])
291-
# the differentiated variable cannot be eliminated
292-
may_eliminate &= isnothing(diff_to_var[v]) && isnothing(var_to_diff[v])
293-
end
294-
locally_structure_simplify!((@view mm[ei, :]), vi, ag, may_eliminate)
307+
locally_structure_simplify!((@view mm[ei, :]), vi, ag, var_to_diff)
295308
end
296309

297310
# Step 2.1: Go backwards, collecting eliminated variables and substituting
@@ -333,7 +346,7 @@ function exactdiv(a::Integer, b)
333346
return d
334347
end
335348

336-
function locally_structure_simplify!(adj_row, pivot_col, ag, may_eliminate)
349+
function locally_structure_simplify!(adj_row, pivot_col, ag, var_to_diff)
337350
pivot_val = adj_row[pivot_col]
338351
iszero(pivot_val) && return false
339352

@@ -375,21 +388,36 @@ function locally_structure_simplify!(adj_row, pivot_col, ag, may_eliminate)
375388
end
376389
end
377390

378-
if may_eliminate && nirreducible <= 1
391+
if nirreducible <= 1
379392
# There were only one or two terms left in the equation (including the
380393
# pivot variable). We can eliminate the pivot variable.
381394
#
382395
# Note that when `nirreducible <= 1`, `alias_candidate` is uniquely
383396
# determined.
384397
if alias_candidate !== 0
398+
# Verify that the derivative depth of the variable is at least
399+
# as deep as that of the alias, otherwise, we can't eliminate.
400+
pivot_var = pivot_col
401+
alias_var = alias_candidate[2]
402+
while (pivot_var = var_to_diff[pivot_col]) !== nothing
403+
alias_var = var_to_diff[alias_var]
404+
alias_var === nothing && return false
405+
end
385406
d, r = divrem(alias_candidate[1], pivot_val)
386407
if r == 0 && (d == 1 || d == -1)
387408
alias_candidate = -d => alias_candidate[2]
388409
else
389410
return false
390411
end
391412
end
392-
ag[pivot_col] = alias_candidate
413+
diff_alias_candidate(ac) = ac === 0 ? 0 : ac[1] => var_to_diff[ac[2]]
414+
while true
415+
@assert !haskey(ag, pivot_col)
416+
ag[pivot_col] = alias_candidate
417+
pivot_col = var_to_diff[pivot_col]
418+
pivot_col === nothing && break
419+
alias_candidate = diff_alias_candidate(alias_candidate)
420+
end
393421
zero!(adj_row)
394422
return true
395423
end

src/systems/systemstructure.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,6 @@ function linear_subsys_adjmat(state::TransformationState)
348348
cadj = Vector{Int}[]
349349
coeffs = Int[]
350350
for (i, eq) in enumerate(eqs)
351-
isdiffeq(eq) && continue
352351
empty!(coeffs)
353352
linear_term = 0
354353
all_int_vars = true

test/odesystem.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -608,17 +608,18 @@ let
608608
@test sol[y] ≈ 0.9 * sol[x[1]] + sol[x[2]]
609609
@test isapprox(sol[x[1]][end], 1, atol = 1e-3)
610610

611-
prob = DAEProblem(sys, [D(y) => 0, D(x[1]) => 0, D(x[2]) => 0], Pair[x[1] => 0.5],
611+
prob = DAEProblem(sys, [D(y) => 0, D(x[1]) => 0, D(x[2]) => 0], [x[1] => 0.5],
612612
(0, 50))
613-
@test prob.u0 ≈ [0.5, 0]
613+
u0_dict = Dict(x[1] => 0.5, x[2] => 0.0)
614+
@test prob.u0 ≈ [u0_dict[x] for x in states(sys)]
614615
@test prob.du0 ≈ [0, 0]
615616
@test prob.p ≈ [1]
616617
sol = solve(prob, IDA())
617618
@test isapprox(sol[x[1]][end], 1, atol = 1e-3)
618619

619620
prob = DAEProblem(sys, [D(y) => 0, D(x[1]) => 0, D(x[2]) => 0], Pair[x[1] => 0.5],
620621
(0, 50), [k => 2])
621-
@test prob.u0 ≈ [0.5, 0]
622+
@test prob.u0 ≈ [u0_dict[x] for x in states(sys)]
622623
@test prob.du0 ≈ [0, 0]
623624
@test prob.p ≈ [2]
624625
sol = solve(prob, IDA())

test/reduction.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,26 @@ eqs = [D(x) ~ σ * (y - x)
246246
lorenz1 = ODESystem(eqs, t, name = :lorenz1)
247247
lorenz1_reduced = structural_simplify(lorenz1)
248248
@test z in Set(parameters(lorenz1_reduced))
249+
250+
# Test that alias elimination can propagate `x ~ 0` to derivatives
251+
@parameters t
252+
@variables x(t) y(t)
253+
254+
eqs = [x ~ 0
255+
D(x) ~ x + y]
256+
trivial0 = ODESystem(eqs, t, name = :trivial0)
257+
let trivial0 = alias_elimination(trivial0)
258+
# For symbolic systems, we currently don't let
259+
# alias elimination touch differential eqs, so
260+
# this leaves one equation left over. In theory,
261+
# the whole system would get eliminated.
262+
@test length(equations(trivial0)) <= 1
263+
@test length(states(trivial0)) <= 1
264+
end
265+
266+
eqs = [D(x) ~ 0]
267+
trivialconst = ODESystem(eqs, t, name = :trivial0)
268+
let trivialconst = alias_elimination(trivialconst)
269+
# Test that alias elimination doesn't eliminate a D(x) that is needed.
270+
@test length(equations(trivialconst)) == length(states(trivialconst)) == 1
271+
end

0 commit comments

Comments
 (0)