Skip to content

Commit 869a825

Browse files
authored
Merge pull request #2274 from SciML/myb/lin
Better tearing diagnostics info
2 parents 68dfae6 + f171537 commit 869a825

File tree

3 files changed

+43
-40
lines changed

3 files changed

+43
-40
lines changed

src/bipartite_graph.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,8 @@ end
733733

734734
Graphs.has_edge(g::DiCMOBiGraph{true}, a, b) = a in inneighbors(g, b)
735735
Graphs.has_edge(g::DiCMOBiGraph{false}, a, b) = b in outneighbors(g, a)
736+
# This definition is required for `induced_subgraph` to work
737+
(::Type{<:DiCMOBiGraph})(n::Integer) = SimpleDiGraph(n)
736738

737739
# Condensation Graphs
738740
abstract type AbstractCondensationGraph <: AbstractGraph{Int} end

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
8383
max(length(var_eq_matching),
8484
maximum(x -> x isa Int ? x : 0, var_eq_matching)))
8585
full_var_eq_matching = copy(var_eq_matching)
86-
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)
86+
var_sccs = find_var_sccs(graph, var_eq_matching)
8787
vargraph = DiCMOBiGraph{true}(graph)
8888
ict = IncrementalCycleTracker(vargraph; dir = :in)
8989

@@ -111,5 +111,5 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
111111
empty!(ieqs)
112112
empty!(filtered_vars)
113113
end
114-
return var_eq_matching, full_var_eq_matching
114+
return var_eq_matching, full_var_eq_matching, var_sccs
115115
end

src/structural_transformation/partial_state_selection.jl

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -299,59 +299,60 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
299299
(n_dummys = length(dummy_derivatives))
300300
@warn "The number of dummy derivatives ($n_dummys) does not match the number of differentiated equations ($n_diff_eqs)."
301301
end
302-
dummy_derivatives_set = BitSet(dummy_derivatives)
303302

304-
is_not_present_non_rec = let graph = graph
305-
v -> isempty(𝑑neighbors(graph, v))
303+
ret = tearing_with_dummy_derivatives(structure, BitSet(dummy_derivatives))
304+
if log
305+
ret
306+
else
307+
ret[1]
306308
end
309+
end
307310

308-
is_not_present = let var_to_diff = var_to_diff
309-
v -> while true
310-
# if a higher derivative is present, then it's present
311-
is_not_present_non_rec(v) || return false
312-
v = var_to_diff[v]
313-
v === nothing && return true
314-
end
311+
function is_present(structure, v)::Bool
312+
@unpack var_to_diff, graph = structure
313+
while true
314+
# if a higher derivative is present, then it's present
315+
isempty(𝑑neighbors(graph, v)) || return true
316+
v = var_to_diff[v]
317+
v === nothing && return false
315318
end
319+
end
316320

317-
# Derivatives that are either in the dummy derivatives set or ended up not
318-
# participating in the system at all are not considered differential
319-
is_some_diff = let dummy_derivatives_set = dummy_derivatives_set
320-
v -> !(v in dummy_derivatives_set) && !is_not_present(v)
321-
end
321+
# Derivatives that are either in the dummy derivatives set or ended up not
322+
# participating in the system at all are not considered differential
323+
function is_some_diff(structure, dummy_derivatives, v)::Bool
324+
!(v in dummy_derivatives) && is_present(structure, v)
325+
end
322326

323-
# We don't want tearing to give us `y_t ~ D(y)`, so we skip equations with
324-
# actually differentiated variables.
325-
isdiffed = let diff_to_var = diff_to_var
326-
v -> diff_to_var[v] !== nothing && is_some_diff(v)
327-
end
327+
# We don't want tearing to give us `y_t ~ D(y)`, so we skip equations with
328+
# actually differentiated variables.
329+
function isdiffed((structure, dummy_derivatives), v)::Bool
330+
@unpack var_to_diff, graph = structure
331+
diff_to_var = invview(var_to_diff)
332+
diff_to_var[v] !== nothing && is_some_diff(structure, dummy_derivatives, v)
333+
end
328334

335+
function tearing_with_dummy_derivatives(structure, dummy_derivatives)
336+
@unpack var_to_diff = structure
329337
# We can eliminate variables that are not a selected state (differential
330338
# variables). Selected states are differentiated variables that are not
331339
# dummy derivatives.
332-
can_eliminate = let var_to_diff = var_to_diff
333-
v -> begin
334-
dv = var_to_diff[v]
335-
dv === nothing && return true
336-
is_some_diff(dv) || return true
337-
return false
340+
can_eliminate = falses(length(var_to_diff))
341+
for (v, dv) in enumerate(var_to_diff)
342+
dv = var_to_diff[v]
343+
if dv === nothing || !is_some_diff(structure, dummy_derivatives, dv)
344+
can_eliminate[v] = true
338345
end
339346
end
340-
341-
var_eq_matching, full_var_eq_matching = tear_graph_modia(structure, isdiffed,
347+
var_eq_matching, full_var_eq_matching, var_sccs = tear_graph_modia(structure,
348+
Base.Fix1(isdiffed, (structure, dummy_derivatives)),
342349
Union{Unassigned, SelectedState};
343-
varfilter = can_eliminate)
350+
varfilter = Base.Fix1(getindex, can_eliminate))
344351
for v in eachindex(var_eq_matching)
345-
is_not_present(v) && continue
352+
is_present(structure, v) || continue
346353
dv = var_to_diff[v]
347-
(dv === nothing || !is_some_diff(dv)) && continue
354+
(dv === nothing || !is_some_diff(structure, dummy_derivatives, dv)) && continue
348355
var_eq_matching[v] = SelectedState()
349356
end
350-
351-
if log
352-
candidates = can_eliminate.(1:ndsts(graph))
353-
return var_eq_matching, full_var_eq_matching, candidates
354-
else
355-
return var_eq_matching
356-
end
357+
return var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate
357358
end

0 commit comments

Comments
 (0)