Skip to content

Commit 37596c4

Browse files
committed
Revert "Revert "Merge branch 'myb/pss' into myb/differential_alias""
This reverts commit 055fbb1.
1 parent ffa11b0 commit 37596c4

File tree

12 files changed

+550
-134
lines changed

12 files changed

+550
-134
lines changed

src/bipartite_graph.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ end
4949
function Matching(m::Int)
5050
Matching{Unassigned}(Union{Int, Unassigned}[unassigned for _ in 1:m], nothing)
5151
end
52+
function Matching{U}(m::Int) where {U}
53+
Matching{Union{Unassigned, U}}(Union{Int, Unassigned, U}[unassigned for _ in 1:m],
54+
nothing)
55+
end
5256

5357
Base.size(m::Matching) = Base.size(m.match)
5458
Base.getindex(m::Matching, i::Integer) = m.match[i]
@@ -65,9 +69,9 @@ function Base.setindex!(m::Matching{U}, v::Union{Integer, U}, i::Integer) where
6569
return m.match[i] = v
6670
end
6771

68-
function Base.push!(m::Matching{U}, v::Union{Integer, U}) where {U}
72+
function Base.push!(m::Matching, v)
6973
push!(m.match, v)
70-
if v !== unassigned && m.inv_match !== nothing
74+
if v isa Integer && m.inv_match !== nothing
7175
m.inv_match[v] = length(m.match)
7276
end
7377
end
@@ -346,8 +350,8 @@ vertices, subject to the constraint that vertices for which `srcfilter` or `dstf
346350
return `false` may not be matched.
347351
"""
348352
function maximal_matching(g::BipartiteGraph, srcfilter = vsrc -> true,
349-
dstfilter = vdst -> true)
350-
matching = Matching(ndsts(g))
353+
dstfilter = vdst -> true, ::Type{U} = Unassigned) where {U}
354+
matching = Matching{U}(ndsts(g))
351355
foreach(Iterators.filter(srcfilter, 𝑠vertices(g))) do vsrc
352356
construct_augmenting_path!(matching, g, vsrc, dstfilter)
353357
end

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ function tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, eqs, va
3535
return nothing
3636
end
3737

38-
function tear_graph_modia(structure::SystemStructure; varfilter = v -> true,
39-
eqfilter = eq -> true)
38+
function tear_graph_modia(structure::SystemStructure, ::Type{U} = Unassigned;
39+
varfilter = v -> true, eqfilter = eq -> true) where {U}
4040
# It would be possible here to simply iterate over all variables and attempt to
4141
# use tearEquations! to produce a matching that greedily selects the minimal
4242
# number of torn variables. However, we can do this process faster if we first
@@ -49,7 +49,7 @@ function tear_graph_modia(structure::SystemStructure; varfilter = v -> true,
4949
# find them here [TODO: It would be good to have an explicit example of this.]
5050

5151
@unpack graph, solvable_graph = structure
52-
var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter))
52+
var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter, U))
5353
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)
5454

5555
for vars in var_sccs

src/structural_transformation/partial_state_selection.jl

Lines changed: 57 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -146,53 +146,48 @@ function partial_state_selection_graph!(structure::SystemStructure, var_eq_match
146146
var_eq_matching
147147
end
148148

149-
function dummy_derivative_graph!(state::TransformationState, jac = nothing)
149+
function dummy_derivative_graph!(state::TransformationState, jac = nothing; kwargs...)
150+
state.structure.solvable_graph === nothing && find_solvables!(state; kwargs...)
150151
var_eq_matching = complete(pantelides!(state))
151152
complete!(state.structure)
152153
dummy_derivative_graph!(state.structure, var_eq_matching, jac)
153154
end
154155

155-
function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, jac)
156-
@unpack eq_to_diff, var_to_diff, graph = structure
157-
diff_to_eq = invview(eq_to_diff)
158-
diff_to_var = invview(var_to_diff)
159-
invgraph = invview(graph)
160-
161-
neqs = nsrcs(graph)
162-
eqlevel = zeros(Int, neqs)
156+
function compute_diff_level(diff_to_x)
157+
nxs = length(diff_to_x)
158+
xlevel = zeros(Int, nxs)
163159
maxlevel = 0
164-
for i in 1:neqs
160+
for i in 1:nxs
165161
level = 0
166-
eq = i
167-
while diff_to_eq[eq] !== nothing
168-
eq = diff_to_eq[eq]
162+
x = i
163+
while diff_to_x[x] !== nothing
164+
x = diff_to_x[x]
169165
level += 1
170166
end
171167
maxlevel = max(maxlevel, level)
172-
eqlevel[i] = level
168+
xlevel[i] = level
173169
end
170+
return xlevel, maxlevel
171+
end
174172

175-
nvars = ndsts(graph)
176-
varlevel = zeros(Int, nvars)
177-
for i in 1:nvars
178-
level = 0
179-
var = i
180-
while diff_to_var[var] !== nothing
181-
var = diff_to_var[var]
182-
level += 1
183-
end
184-
maxlevel = max(maxlevel, level)
185-
varlevel[i] = level
186-
end
173+
function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, jac)
174+
@unpack eq_to_diff, var_to_diff, graph = structure
175+
diff_to_eq = invview(eq_to_diff)
176+
diff_to_var = invview(var_to_diff)
177+
invgraph = invview(graph)
178+
179+
eqlevel, _ = compute_diff_level(diff_to_eq)
180+
varlevel, _ = compute_diff_level(diff_to_var)
187181

188182
var_sccs = find_var_sccs(graph, var_eq_matching)
189-
eqcolor = falses(neqs)
183+
eqcolor = falses(nsrcs(graph))
190184
dummy_derivatives = Int[]
191185
col_order = Int[]
186+
nvars = ndsts(graph)
192187
for vars in var_sccs
193188
eqs = [var_eq_matching[var] for var in vars if var_eq_matching[var] !== unassigned]
194189
isempty(eqs) && continue
195-
maxlevel = maximum(map(x -> eqlevel[x], eqs))
190+
maxlevel = maximum(Base.Fix1(getindex, eqlevel), eqs)
196191
iszero(maxlevel) && continue
197192

198193
rank_matching = Matching(nvars)
@@ -220,8 +215,10 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
220215
else
221216
rank = 0
222217
for var in vars
218+
# We need `invgraph` here because we are matching from
219+
# variables to equations.
223220
pathfound = construct_augmenting_path!(rank_matching, invgraph, var,
224-
eq -> eq in eqs_set, eqcolor)
221+
Base.Fix2(in, eqs_set), eqcolor)
225222
pathfound || continue
226223
push!(dummy_derivatives, var)
227224
rank += 1
@@ -239,5 +236,35 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
239236
end
240237
end
241238

242-
dummy_derivatives
239+
dummy_derivatives_set = BitSet(dummy_derivatives)
240+
# We can eliminate variables that are not a selected state (differential
241+
# variables). Selected states are differentiated variables that are not
242+
# dummy derivatives.
243+
can_eliminate = let var_to_diff = var_to_diff,
244+
dummy_derivatives_set = dummy_derivatives_set
245+
246+
v -> begin
247+
dv = var_to_diff[v]
248+
dv === nothing || dv in dummy_derivatives_set
249+
end
250+
end
251+
252+
# We don't want tearing to give us `y_t ~ D(y)`, so we skip equations with
253+
# actually differentiated variables.
254+
isdiffed = let diff_to_var = diff_to_var, dummy_derivatives_set = dummy_derivatives_set
255+
v -> diff_to_var[v] !== nothing && !(v in dummy_derivatives_set)
256+
end
257+
should_consider = let graph = graph, isdiffed = isdiffed
258+
eq -> !any(isdiffed, 𝑠neighbors(graph, eq))
259+
end
260+
261+
var_eq_matching = tear_graph_modia(structure, Union{Unassigned, SelectedState};
262+
varfilter = can_eliminate,
263+
eqfilter = should_consider)
264+
for v in eachindex(var_eq_matching)
265+
can_eliminate(v) && continue
266+
var_eq_matching[v] = SelectedState()
267+
end
268+
269+
return var_eq_matching
243270
end

0 commit comments

Comments
 (0)