Skip to content

Commit eba915f

Browse files
committed
Merge branch 'master' into controlsystem
2 parents b90f87d + d45503c commit eba915f

File tree

9 files changed

+344
-70
lines changed

9 files changed

+344
-70
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,13 @@ Unitful = "1.1"
7979
julia = "1.6"
8080

8181
[extras]
82+
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
8283
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
8384
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
85+
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
86+
Ipopt_jll = "9cc047cb-c261-5740-88fc-0cf96f7bdcc7"
8487
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
88+
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
8589
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
8690
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
8791
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -93,4 +97,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
9397
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9498

9599
[targets]
96-
test = ["BenchmarkTools", "ForwardDiff", "Optimization", "OptimizationOptimJL", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
100+
test = ["AmplNLWriter", "BenchmarkTools", "ForwardDiff", "Ipopt", "Ipopt_jll", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]

src/bipartite_graph.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,53 @@ function complete(g::BipartiteGraph{I}) where {I}
184184
BipartiteGraph(g.ne, g.fadjlist, badjlist)
185185
end
186186

187+
# Matrix whose only purpose is to pretty-print the bipartite graph
188+
struct BipartiteAdjacencyList
189+
u::Union{Vector{Int}, Nothing}
190+
end
191+
function Base.show(io::IO, l::BipartiteAdjacencyList)
192+
if l.u === nothing
193+
printstyled(io, '', color = :light_black)
194+
elseif isempty(l.u)
195+
printstyled(io, '', color = :light_black)
196+
else
197+
print(io, l.u)
198+
end
199+
end
200+
201+
struct Label
202+
s::String
203+
end
204+
Base.show(io::IO, l::Label) = print(io, l.s)
205+
206+
struct BipartiteGraphPrintMatrix <:
207+
AbstractMatrix{Union{Label, Int, BipartiteAdjacencyList}}
208+
bpg::BipartiteGraph
209+
end
210+
Base.size(bgpm::BipartiteGraphPrintMatrix) = (max(nsrcs(bgpm.bpg), ndsts(bgpm.bpg)) + 1, 3)
211+
function Base.getindex(bgpm::BipartiteGraphPrintMatrix, i::Integer, j::Integer)
212+
checkbounds(bgpm, i, j)
213+
if i == 1
214+
return (Label.(("#", "src", "dst")))[j]
215+
elseif j == 1
216+
return i - 1
217+
elseif j == 2
218+
return BipartiteAdjacencyList(i - 1 <= nsrcs(bgpm.bpg) ?
219+
𝑠neighbors(bgpm.bpg, i - 1) : nothing)
220+
elseif j == 3
221+
return BipartiteAdjacencyList(i - 1 <= ndsts(bgpm.bpg) ?
222+
𝑑neighbors(bgpm.bpg, i - 1) : nothing)
223+
else
224+
@assert false
225+
end
226+
end
227+
228+
function Base.show(io::IO, b::BipartiteGraph)
229+
print(io, "BipartiteGraph with (", length(b.fadjlist), ", ",
230+
isa(b.badjlist, Int) ? b.badjlist : length(b.badjlist), ") (𝑠,𝑑)-vertices\n")
231+
Base.print_matrix(io, BipartiteGraphPrintMatrix(b))
232+
end
233+
187234
"""
188235
```julia
189236
Base.isequal(bg1::BipartiteGraph{T}, bg2::BipartiteGraph{T}) where {T<:Integer}

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,57 @@ end
3737

3838
function tear_graph_modia(structure::SystemStructure; varfilter = v -> true,
3939
eqfilter = eq -> true)
40+
# It would be possible here to simply iterate over all variables and attempt to
41+
# use tearEquations! to produce a matching that greedily selects the minimal
42+
# number of torn variables. However, we can do this process faster if we first
43+
# compute the strongly connected components. In the absence of cycles and
44+
# non-solvability, a maximal matching on the original graph will give us an
45+
# optimal assignment. However, even with cycles, we can use the maximal matching
46+
# to give us a good starting point for a good matching and then proceed to
47+
# reverse edges in each scc to improve the solution. Note that it is possible
48+
# to have optimal solutions that cannot be found by this process. We will not
49+
# find them here [TODO: It would be good to have an explicit example of this.]
50+
4051
@unpack graph, solvable_graph = structure
4152
var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter))
4253
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)
4354

55+
# Here, we're using a maximal matching on the post-pantelides system to find
56+
# the strongly connected components of the system (of variables that depend
57+
# on each other). The strongly connected components are unique, however, the
58+
# maximal matching itself is not. Every maximal matching gives rise to the
59+
# same set of strongly connected components, but the associated equations need
60+
# not be the same. In the absence of solvability constraints, this may be a
61+
# small issue, but here it is possible that an equation got assigned to an
62+
# scc that cannot actually use it for solving a variable, but still precludes
63+
# another scc from using it. To avoid this, we delete any assignments that
64+
# are not in the solvable graph and extend the set of considered eqauations
65+
# below.
66+
for var in ndsts(solvable_graph)
67+
var_eq_matching[var] === unassigned && continue
68+
if !(BipartiteEdge(var, var_eq_matching[var]) in solvable_graph)
69+
var_eq_matching[var] = unassigned
70+
end
71+
end
72+
4473
for vars in var_sccs
4574
filtered_vars = filter(varfilter, vars)
4675
ieqs = Int[var_eq_matching[v]
4776
for v in filtered_vars if var_eq_matching[v] !== unassigned]
4877
for var in vars
4978
var_eq_matching[var] = unassigned
5079
end
80+
for var in filtered_vars
81+
# Add any equations that we may not have been able to use earlier to see
82+
# if a different matching may have been possible.
83+
for eq′ in 𝑑neighbors(solvable_graph, var)
84+
eqfilter(eq′) || continue
85+
eq′ in ieqs && continue
86+
if invview(var_eq_matching)[eq′] === unassigned
87+
push!(ieqs, eq′)
88+
end
89+
end
90+
end
5191
tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, ieqs, filtered_vars)
5292
end
5393

src/systems/alias_elimination.jl

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,13 @@ function alias_elimination(sys)
5757

5858
newstates = []
5959
for j in eachindex(fullvars)
60-
if !(j in keys(ag))
60+
if j in keys(ag)
61+
# Put back equations for alias eliminated dervars
62+
if isdervar(state.structure, j) &&
63+
!(invview(state.structure.var_to_diff)[j] in keys(ag))
64+
push!(eqs, fullvars[j] ~ subs[fullvars[j]])
65+
end
66+
else
6167
isdervar(state.structure, j) || push!(newstates, fullvars[j])
6268
end
6369
end
@@ -192,6 +198,7 @@ struct AliasGraphKeySet <: AbstractSet{Int}
192198
end
193199
Base.keys(ag::AliasGraph) = AliasGraphKeySet(ag)
194200
Base.iterate(agk::AliasGraphKeySet, state...) = Base.iterate(agk.ag.eliminated, state...)
201+
Base.length(agk::AliasGraphKeySet) = Base.length(agk.ag.eliminated)
195202
function Base.in(i::Int, agk::AliasGraphKeySet)
196203
aliasto = agk.ag.aliasto
197204
1 <= i <= length(aliasto) && aliasto[i] !== nothing
@@ -210,9 +217,11 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
210217
is_linear_equations[e] = true
211218
end
212219

213-
# Variables that are highest order differentiated cannot be states of an ODE
214-
is_not_potential_state = isnothing.(var_to_diff)
215-
is_linear_variables = copy(is_not_potential_state)
220+
# For now, only consider variables linear that are not differentiated.
221+
# We could potentially apply the same logic to variables whose derivative
222+
# is also linear, but that's a TODO.
223+
diff_to_var = invview(var_to_diff)
224+
is_linear_variables = .&(isnothing.(var_to_diff), isnothing.(diff_to_var))
216225
for i in 𝑠vertices(graph)
217226
is_linear_equations[i] && continue
218227
for j in 𝑠neighbors(graph, i)
@@ -230,11 +239,9 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
230239
r !== nothing && return r
231240
rank1 = k - 1
232241
end
233-
if rank2 === nothing
234-
r = find_masked_pivot(is_not_potential_state, M, k)
235-
r !== nothing && return r
236-
rank2 = k - 1
237-
end
242+
# TODO: It would be better to sort the variables by
243+
# derivative order here to enable more elimination
244+
# opportunities.
238245
return find_masked_pivot(nothing, M, k)
239246
end
240247
function find_and_record_pivot(M, k)
@@ -249,10 +256,9 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
249256
end
250257
bareiss_ops = ((M, i, j) -> nothing, myswaprows!,
251258
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
252-
rank3, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
253-
rank1 = something(rank1, rank3)
254-
rank2 = something(rank2, rank3)
255-
(rank1, rank2, rank3, pivots)
259+
rank2, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
260+
rank1 = something(rank1, rank2)
261+
(rank1, rank2, pivots)
256262
end
257263

258264
return mm, solvable_variables, do_bareiss!(mm, mm_orig)
@@ -266,16 +272,27 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
266272
# variables`.
267273
#
268274
# `do_bareiss` conceptually gives us this system:
269-
# rank1 | [ M₁₁ M₁₂ | M₁₃ M₁₄ ] [v₁] = [0]
270-
# rank2 | [ 0 M₂₂ | M₂₃ M₂₄ ] P [v₂] = [0]
275+
# rank1 | [ M₁₁ M₁₂ | M₁₃ ] [v₁] = [0]
276+
# rank2 | [ 0 M₂₂ | M₂₃ ] P [v₂] = [0]
271277
# -------------------|------------------------
272-
# rank3 | [ 0 0 | M₃₃ M₃₄ ] [v₃] = [0]
273-
# [ 0 0 | 0 0 ] [v₄] = [0]
274-
mm, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss!(graph, var_to_diff,
275-
mm_orig)
278+
# [ 0 0 | 0 ] [v₃] = [0]
279+
#
280+
# Where `v₁` are the purely linear variables (i.e. those that only appear in linear equations),
281+
# `v₂` are the variables that may be potentially solved by the linear system and v₃ are the variables
282+
# that contribute to the equations, but are not solved by the linear system. Note
283+
# that the complete system may be larger than the linear subsystem and include variables
284+
# that do not appear here.
285+
mm, solvable_variables, (rank1, rank2, pivots) = aag_bareiss!(graph, var_to_diff,
286+
mm_orig)
276287

277288
# Step 2: Simplify the system using the Bareiss factorization
289+
278290
ag = AliasGraph(size(mm, 2))
291+
292+
# First, eliminate variables that only appear in linear equations and were removed
293+
# completely from the coefficient matrix. These are technically singularities in
294+
# the matrix, but assigning them to 0 is a feasible assignment and works well in
295+
# practice.
279296
for v in setdiff(solvable_variables, @view pivots[1:rank1])
280297
ag[v] = 0
281298
end
@@ -287,11 +304,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
287304
function lss!(ei::Integer)
288305
vi = pivots[ei]
289306
may_eliminate = true
290-
for v in 𝑠neighbors(graph, mm.nzrows[ei])
291-
# the differentiated variable cannot be eliminated
292-
may_eliminate &= isnothing(diff_to_var[v]) && isnothing(var_to_diff[v])
293-
end
294-
locally_structure_simplify!((@view mm[ei, :]), vi, ag, may_eliminate)
307+
locally_structure_simplify!((@view mm[ei, :]), vi, ag, var_to_diff)
295308
end
296309

297310
# Step 2.1: Go backwards, collecting eliminated variables and substituting
@@ -333,7 +346,7 @@ function exactdiv(a::Integer, b)
333346
return d
334347
end
335348

336-
function locally_structure_simplify!(adj_row, pivot_col, ag, may_eliminate)
349+
function locally_structure_simplify!(adj_row, pivot_col, ag, var_to_diff)
337350
pivot_val = adj_row[pivot_col]
338351
iszero(pivot_val) && return false
339352

@@ -375,21 +388,36 @@ function locally_structure_simplify!(adj_row, pivot_col, ag, may_eliminate)
375388
end
376389
end
377390

378-
if may_eliminate && nirreducible <= 1
391+
if nirreducible <= 1
379392
# There were only one or two terms left in the equation (including the
380393
# pivot variable). We can eliminate the pivot variable.
381394
#
382395
# Note that when `nirreducible <= 1`, `alias_candidate` is uniquely
383396
# determined.
384397
if alias_candidate !== 0
398+
# Verify that the derivative depth of the variable is at least
399+
# as deep as that of the alias, otherwise, we can't eliminate.
400+
pivot_var = pivot_col
401+
alias_var = alias_candidate[2]
402+
while (pivot_var = var_to_diff[pivot_col]) !== nothing
403+
alias_var = var_to_diff[alias_var]
404+
alias_var === nothing && return false
405+
end
385406
d, r = divrem(alias_candidate[1], pivot_val)
386407
if r == 0 && (d == 1 || d == -1)
387408
alias_candidate = -d => alias_candidate[2]
388409
else
389410
return false
390411
end
391412
end
392-
ag[pivot_col] = alias_candidate
413+
diff_alias_candidate(ac) = ac === 0 ? 0 : ac[1] => var_to_diff[ac[2]]
414+
while true
415+
@assert !haskey(ag, pivot_col)
416+
ag[pivot_col] = alias_candidate
417+
pivot_col = var_to_diff[pivot_col]
418+
pivot_col === nothing && break
419+
alias_candidate = diff_alias_candidate(alias_candidate)
420+
end
393421
zero!(adj_row)
394422
return true
395423
end

0 commit comments

Comments
 (0)