Skip to content

Commit bdc7de9

Browse files
authored
Merge pull request #1677 from SciML/revert-1635-kf/morealias
Revert "Strengthen alias elimination"
2 parents 7557040 + c13442e commit bdc7de9

File tree

5 files changed

+31
-122
lines changed

5 files changed

+31
-122
lines changed

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -37,57 +37,17 @@ end
3737

3838
function tear_graph_modia(structure::SystemStructure; varfilter = v -> true,
3939
eqfilter = eq -> true)
40-
# It would be possible here to simply iterate over all variables and attempt to
41-
# use tearEquations! to produce a matching that greedily selects the minimal
42-
# number of torn variables. However, we can do this process faster if we first
43-
# compute the strongly connected components. In the absence of cycles and
44-
# non-solvability, a maximal matching on the original graph will give us an
45-
# optimal assignment. However, even with cycles, we can use the maximal matching
46-
# to give us a good starting point for a good matching and then proceed to
47-
# reverse edges in each scc to improve the solution. Note that it is possible
48-
# to have optimal solutions that cannot be found by this process. We will not
49-
# find them here [TODO: It would be good to have an explicit example of this.]
50-
5140
@unpack graph, solvable_graph = structure
5241
var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter))
5342
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)
5443

55-
# Here, we're using a maximal matching on the post-pantelides system to find
56-
# the strongly connected components of the system (of variables that depend
57-
# on each other). The strongly connected components are unique, however, the
58-
# maximal matching itself is not. Every maximal matching gives rise to the
59-
# same set of strongly connected components, but the associated equations need
60-
# not be the same. In the absence of solvability constraints, this may be a
61-
# small issue, but here it is possible that an equation got assigned to an
62-
# scc that cannot actually use it for solving a variable, but still precludes
63-
# another scc from using it. To avoid this, we delete any assignments that
64-
# are not in the solvable graph and extend the set of considered eqauations
65-
# below.
66-
for var in ndsts(solvable_graph)
67-
var_eq_matching[var] === unassigned && continue
68-
if !(BipartiteEdge(var, var_eq_matching[var]) in solvable_graph)
69-
var_eq_matching[var] = unassigned
70-
end
71-
end
72-
7344
for vars in var_sccs
7445
filtered_vars = filter(varfilter, vars)
7546
ieqs = Int[var_eq_matching[v]
7647
for v in filtered_vars if var_eq_matching[v] !== unassigned]
7748
for var in vars
7849
var_eq_matching[var] = unassigned
7950
end
80-
for var in filtered_vars
81-
# Add any equations that we may not have been able to use earlier to see
82-
# if a different matching may have been possible.
83-
for eq′ in 𝑑neighbors(solvable_graph, var)
84-
eqfilter(eq′) || continue
85-
eq′ in ieqs && continue
86-
if invview(var_eq_matching)[eq′] === unassigned
87-
push!(ieqs, eq′)
88-
end
89-
end
90-
end
9151
tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, ieqs, filtered_vars)
9252
end
9353

src/systems/alias_elimination.jl

Lines changed: 27 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,7 @@ function alias_elimination(sys)
5757

5858
newstates = []
5959
for j in eachindex(fullvars)
60-
if j in keys(ag)
61-
# Put back equations for alias eliminated dervars
62-
if isdervar(state.structure, j) &&
63-
!(invview(state.structure.var_to_diff)[j] in keys(ag))
64-
push!(eqs, fullvars[j] ~ subs[fullvars[j]])
65-
end
66-
else
60+
if !(j in keys(ag))
6761
isdervar(state.structure, j) || push!(newstates, fullvars[j])
6862
end
6963
end
@@ -198,7 +192,6 @@ struct AliasGraphKeySet <: AbstractSet{Int}
198192
end
199193
Base.keys(ag::AliasGraph) = AliasGraphKeySet(ag)
200194
Base.iterate(agk::AliasGraphKeySet, state...) = Base.iterate(agk.ag.eliminated, state...)
201-
Base.length(agk::AliasGraphKeySet) = Base.length(agk.ag.eliminated)
202195
function Base.in(i::Int, agk::AliasGraphKeySet)
203196
aliasto = agk.ag.aliasto
204197
1 <= i <= length(aliasto) && aliasto[i] !== nothing
@@ -217,11 +210,9 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
217210
is_linear_equations[e] = true
218211
end
219212

220-
# For now, only consider variables linear that are not differentiated.
221-
# We could potentially apply the same logic to variables whose derivative
222-
# is also linear, but that's a TODO.
223-
diff_to_var = invview(var_to_diff)
224-
is_linear_variables = .&(isnothing.(var_to_diff), isnothing.(diff_to_var))
213+
# Variables that are highest order differentiated cannot be states of an ODE
214+
is_not_potential_state = isnothing.(var_to_diff)
215+
is_linear_variables = copy(is_not_potential_state)
225216
for i in 𝑠vertices(graph)
226217
is_linear_equations[i] && continue
227218
for j in 𝑠neighbors(graph, i)
@@ -239,9 +230,11 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
239230
r !== nothing && return r
240231
rank1 = k - 1
241232
end
242-
# TODO: It would be better to sort the variables by
243-
# derivative order here to enable more elimination
244-
# opportunities.
233+
if rank2 === nothing
234+
r = find_masked_pivot(is_not_potential_state, M, k)
235+
r !== nothing && return r
236+
rank2 = k - 1
237+
end
245238
return find_masked_pivot(nothing, M, k)
246239
end
247240
function find_and_record_pivot(M, k)
@@ -256,9 +249,10 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
256249
end
257250
bareiss_ops = ((M, i, j) -> nothing, myswaprows!,
258251
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
259-
rank2, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
260-
rank1 = something(rank1, rank2)
261-
(rank1, rank2, pivots)
252+
rank3, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
253+
rank1 = something(rank1, rank3)
254+
rank2 = something(rank2, rank3)
255+
(rank1, rank2, rank3, pivots)
262256
end
263257

264258
return mm, solvable_variables, do_bareiss!(mm, mm_orig)
@@ -272,27 +266,16 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
272266
# variables`.
273267
#
274268
# `do_bareiss` conceptually gives us this system:
275-
# rank1 | [ M₁₁ M₁₂ | M₁₃ ] [v₁] = [0]
276-
# rank2 | [ 0 M₂₂ | M₂₃ ] P [v₂] = [0]
269+
# rank1 | [ M₁₁ M₁₂ | M₁₃ M₁₄ ] [v₁] = [0]
270+
# rank2 | [ 0 M₂₂ | M₂₃ M₂₄ ] P [v₂] = [0]
277271
# -------------------|------------------------
278-
# [ 0 0 | 0 ] [v₃] = [0]
279-
#
280-
# Where `v₁` are the purely linear variables (i.e. those that only appear in linear equations),
281-
# `v₂` are the variables that may be potentially solved by the linear system and v₃ are the variables
282-
# that contribute to the equations, but are not solved by the linear system. Note
283-
# that the complete system may be larger than the linear subsystem and include variables
284-
# that do not appear here.
285-
mm, solvable_variables, (rank1, rank2, pivots) = aag_bareiss!(graph, var_to_diff,
286-
mm_orig)
272+
# rank3 | [ 0 0 | M₃₃ M₃₄ ] [v₃] = [0]
273+
# [ 0 0 | 0 0 ] [v₄] = [0]
274+
mm, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss!(graph, var_to_diff,
275+
mm_orig)
287276

288277
# Step 2: Simplify the system using the Bareiss factorization
289-
290278
ag = AliasGraph(size(mm, 2))
291-
292-
# First, eliminate variables that only appear in linear equations and were removed
293-
# completely from the coefficient matrix. These are technically singularities in
294-
# the matrix, but assigning them to 0 is a feasible assignment and works well in
295-
# practice.
296279
for v in setdiff(solvable_variables, @view pivots[1:rank1])
297280
ag[v] = 0
298281
end
@@ -304,7 +287,11 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
304287
function lss!(ei::Integer)
305288
vi = pivots[ei]
306289
may_eliminate = true
307-
locally_structure_simplify!((@view mm[ei, :]), vi, ag, var_to_diff)
290+
for v in 𝑠neighbors(graph, mm.nzrows[ei])
291+
# the differentiated variable cannot be eliminated
292+
may_eliminate &= isnothing(diff_to_var[v]) && isnothing(var_to_diff[v])
293+
end
294+
locally_structure_simplify!((@view mm[ei, :]), vi, ag, may_eliminate)
308295
end
309296

310297
# Step 2.1: Go backwards, collecting eliminated variables and substituting
@@ -346,7 +333,7 @@ function exactdiv(a::Integer, b)
346333
return d
347334
end
348335

349-
function locally_structure_simplify!(adj_row, pivot_col, ag, var_to_diff)
336+
function locally_structure_simplify!(adj_row, pivot_col, ag, may_eliminate)
350337
pivot_val = adj_row[pivot_col]
351338
iszero(pivot_val) && return false
352339

@@ -388,36 +375,21 @@ function locally_structure_simplify!(adj_row, pivot_col, ag, var_to_diff)
388375
end
389376
end
390377

391-
if nirreducible <= 1
378+
if may_eliminate && nirreducible <= 1
392379
# There were only one or two terms left in the equation (including the
393380
# pivot variable). We can eliminate the pivot variable.
394381
#
395382
# Note that when `nirreducible <= 1`, `alias_candidate` is uniquely
396383
# determined.
397384
if alias_candidate !== 0
398-
# Verify that the derivative depth of the variable is at least
399-
# as deep as that of the alias, otherwise, we can't eliminate.
400-
pivot_var = pivot_col
401-
alias_var = alias_candidate[2]
402-
while (pivot_var = var_to_diff[pivot_col]) !== nothing
403-
alias_var = var_to_diff[alias_var]
404-
alias_var === nothing && return false
405-
end
406385
d, r = divrem(alias_candidate[1], pivot_val)
407386
if r == 0 && (d == 1 || d == -1)
408387
alias_candidate = -d => alias_candidate[2]
409388
else
410389
return false
411390
end
412391
end
413-
diff_alias_candidate(ac) = ac === 0 ? 0 : ac[1] => var_to_diff[ac[2]]
414-
while true
415-
@assert !haskey(ag, pivot_col)
416-
ag[pivot_col] = alias_candidate
417-
pivot_col = var_to_diff[pivot_col]
418-
pivot_col === nothing && break
419-
alias_candidate = diff_alias_candidate(alias_candidate)
420-
end
392+
ag[pivot_col] = alias_candidate
421393
zero!(adj_row)
422394
return true
423395
end

src/systems/systemstructure.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ function linear_subsys_adjmat(state::TransformationState)
348348
cadj = Vector{Int}[]
349349
coeffs = Int[]
350350
for (i, eq) in enumerate(eqs)
351+
isdiffeq(eq) && continue
351352
empty!(coeffs)
352353
linear_term = 0
353354
all_int_vars = true

test/odesystem.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -608,18 +608,17 @@ let
608608
@test sol[y] 0.9 * sol[x[1]] + sol[x[2]]
609609
@test isapprox(sol[x[1]][end], 1, atol = 1e-3)
610610

611-
prob = DAEProblem(sys, [D(y) => 0, D(x[1]) => 0, D(x[2]) => 0], [x[1] => 0.5],
611+
prob = DAEProblem(sys, [D(y) => 0, D(x[1]) => 0, D(x[2]) => 0], Pair[x[1] => 0.5],
612612
(0, 50))
613-
u0_dict = Dict(x[1] => 0.5, x[2] => 0.0)
614-
@test prob.u0 [u0_dict[x] for x in states(sys)]
613+
@test prob.u0 [0.5, 0]
615614
@test prob.du0 [0, 0]
616615
@test prob.p [1]
617616
sol = solve(prob, IDA())
618617
@test isapprox(sol[x[1]][end], 1, atol = 1e-3)
619618

620619
prob = DAEProblem(sys, [D(y) => 0, D(x[1]) => 0, D(x[2]) => 0], Pair[x[1] => 0.5],
621620
(0, 50), [k => 2])
622-
@test prob.u0 [u0_dict[x] for x in states(sys)]
621+
@test prob.u0 [0.5, 0]
623622
@test prob.du0 [0, 0]
624623
@test prob.p [2]
625624
sol = solve(prob, IDA())

test/reduction.jl

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -246,26 +246,3 @@ eqs = [D(x) ~ σ * (y - x)
246246
lorenz1 = ODESystem(eqs, t, name = :lorenz1)
247247
lorenz1_reduced = structural_simplify(lorenz1)
248248
@test z in Set(parameters(lorenz1_reduced))
249-
250-
# Test that alias elimination can propagate `x ~ 0` to derivatives
251-
@parameters t
252-
@variables x(t) y(t)
253-
254-
eqs = [x ~ 0
255-
D(x) ~ x + y]
256-
trivial0 = ODESystem(eqs, t, name = :trivial0)
257-
let trivial0 = alias_elimination(trivial0)
258-
# For symbolic systems, we currently don't let
259-
# alias elimination touch differential eqs, so
260-
# this leaves one equation left over. In theory,
261-
# the whole system would get eliminated.
262-
@test length(equations(trivial0)) <= 1
263-
@test length(states(trivial0)) <= 1
264-
end
265-
266-
eqs = [D(x) ~ 0]
267-
trivialconst = ODESystem(eqs, t, name = :trivial0)
268-
let trivialconst = alias_elimination(trivialconst)
269-
# Test that alias elimination doesn't eliminate a D(x) that is needed.
270-
@test length(equations(trivialconst)) == length(states(trivialconst)) == 1
271-
end

0 commit comments

Comments
 (0)