Skip to content

Commit 9e2d60a

Browse files
committed
Make consistency check more generic
1 parent 6c31043 commit 9e2d60a

File tree

4 files changed

+10
-7
lines changed

4 files changed

+10
-7
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2424
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
2525
AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
26-
fast_substitute
26+
fast_substitute, get_fullvars, has_equations
2727

2828
using ModelingToolkit.BipartiteGraphs
2929
import .BipartiteGraphs: invview, complete

src/structural_transformation/utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ end
1414

1515
function error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs)
1616
io = IOBuffer()
17-
neqs = length(equations(state))
17+
neqs = length(ndsts(state.structure.graph))
1818
if iseqs
1919
error_title = "More equations than variables, here are the potential extra equation(s):\n"
20-
out_arr = equations(state)[bad_idxs]
20+
out_arr = has_equations(state) ? equations(state)[bad_idxs] : bad_idxs
2121
else
2222
error_title = "More variables than equations, here are the potential extra variable(s):\n"
23-
out_arr = state.fullvars[bad_idxs]
23+
out_arr = get_fullvars(state)[bad_idxs]
2424
unset_inputs = intersect(out_arr, orig_inputs)
2525
n_missing_eqs = n_highest_vars - neqs
2626
n_unset_inputs = length(unset_inputs)
@@ -52,8 +52,8 @@ end
5252
###
5353
### Structural check
5454
###
55-
function check_consistency(state::TearingState, ag, orig_inputs)
56-
fullvars = state.fullvars
55+
function check_consistency(state::TransformationState, ag, orig_inputs)
56+
fullvars = get_fullvars(state)
5757
@unpack graph, var_to_diff = state.structure
5858
n_highest_vars = count(v -> var_to_diff[v] === nothing &&
5959
!isempty(𝑑neighbors(graph, v)) &&

src/systems/abstractsystem.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ for prop in [:eqs
238238
end
239239
end
240240

241+
has_equations(::AbstractSystem) = true
242+
241243
const EMPTY_TGRAD = Vector{Num}(undef, 0)
242244
const EMPTY_JAC = Matrix{Num}(undef, 0, 0)
243245
function invalidate_cache!(sys::AbstractSystem)

src/systems/systemstructure.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten,
1010
isparameter, isconstant,
1111
independent_variables, SparseMatrixCLIL, AbstractSystem,
1212
equations, isirreducible, input_timedomain, TimeDomain,
13-
VariableType, getvariabletype
13+
VariableType, getvariabletype, has_equations
1414
using ..BipartiteGraphs
1515
import ..BipartiteGraphs: invview, complete
1616
using Graphs
@@ -140,6 +140,7 @@ abstract type TransformationState{T} end
140140
abstract type AbstractTearingState{T} <: TransformationState{T} end
141141

142142
get_fullvars(ts::TransformationState) = ts.fullvars
143+
has_equations(::TransformationState) = true
143144

144145
Base.@kwdef mutable struct SystemStructure
145146
# Maps the (index of) a variable to the (index of) the variable describing

0 commit comments

Comments
 (0)