Skip to content

Commit 528e663

Browse files
authored
Merge pull request #2148 from SciML/myb/generic_consistency
Make consistency check more generic
2 parents 6c31043 + ad0f802 commit 528e663

File tree

4 files changed

+17
-8
lines changed

4 files changed

+17
-8
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2424
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
2525
AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
26-
fast_substitute
26+
fast_substitute, get_fullvars, has_equations
2727

2828
using ModelingToolkit.BipartiteGraphs
2929
import .BipartiteGraphs: invview, complete

src/structural_transformation/utils.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,21 @@ function BipartiteGraphs.maximal_matching(s::SystemStructure, eqfilter = eq -> t
1212
maximal_matching(s.graph, eqfilter, varfilter)
1313
end
1414

15+
n_concrete_eqs(state::TransformationState) = n_concrete_eqs(state.structure)
16+
n_concrete_eqs(structure::SystemStructure) = n_concrete_eqs(structure.graph)
17+
function n_concrete_eqs(graph::BipartiteGraph)
18+
neqs = count(e -> !isempty(𝑠neighbors(graph, e)), 𝑠vertices(graph))
19+
end
20+
1521
function error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs)
1622
io = IOBuffer()
17-
neqs = length(equations(state))
23+
neqs = n_concrete_eqs(state)
1824
if iseqs
1925
error_title = "More equations than variables, here are the potential extra equation(s):\n"
20-
out_arr = equations(state)[bad_idxs]
26+
out_arr = has_equations(state) ? equations(state)[bad_idxs] : bad_idxs
2127
else
2228
error_title = "More variables than equations, here are the potential extra variable(s):\n"
23-
out_arr = state.fullvars[bad_idxs]
29+
out_arr = get_fullvars(state)[bad_idxs]
2430
unset_inputs = intersect(out_arr, orig_inputs)
2531
n_missing_eqs = n_highest_vars - neqs
2632
n_unset_inputs = length(unset_inputs)
@@ -52,14 +58,14 @@ end
5258
###
5359
### Structural check
5460
###
55-
function check_consistency(state::TearingState, ag, orig_inputs)
56-
fullvars = state.fullvars
61+
function check_consistency(state::TransformationState, ag, orig_inputs)
62+
fullvars = get_fullvars(state)
63+
neqs = n_concrete_eqs(state)
5764
@unpack graph, var_to_diff = state.structure
5865
n_highest_vars = count(v -> var_to_diff[v] === nothing &&
5966
!isempty(𝑑neighbors(graph, v)) &&
6067
(ag === nothing || !haskey(ag, v) || ag[v] != v),
6168
vertices(var_to_diff))
62-
neqs = nsrcs(graph)
6369
is_balanced = n_highest_vars == neqs
6470

6571
if neqs > 0 && !is_balanced

src/systems/abstractsystem.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ for prop in [:eqs
238238
end
239239
end
240240

241+
has_equations(::AbstractSystem) = true
242+
241243
const EMPTY_TGRAD = Vector{Num}(undef, 0)
242244
const EMPTY_JAC = Matrix{Num}(undef, 0, 0)
243245
function invalidate_cache!(sys::AbstractSystem)

src/systems/systemstructure.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten,
1010
isparameter, isconstant,
1111
independent_variables, SparseMatrixCLIL, AbstractSystem,
1212
equations, isirreducible, input_timedomain, TimeDomain,
13-
VariableType, getvariabletype
13+
VariableType, getvariabletype, has_equations
1414
using ..BipartiteGraphs
1515
import ..BipartiteGraphs: invview, complete
1616
using Graphs
@@ -140,6 +140,7 @@ abstract type TransformationState{T} end
140140
abstract type AbstractTearingState{T} <: TransformationState{T} end
141141

142142
get_fullvars(ts::TransformationState) = ts.fullvars
143+
has_equations(::TransformationState) = true
143144

144145
Base.@kwdef mutable struct SystemStructure
145146
# Maps the (index of) a variable to the (index of) the variable describing

0 commit comments

Comments
 (0)