Skip to content

Commit 5dc53f0

Browse files
Merge pull request #1086 from bolognam/mdb/unbalanced-error-messages
Return Unbalanced Equations/Variables
2 parents 8ff75e3 + 1b8134c commit 5dc53f0

File tree

7 files changed

+131
-14
lines changed

7 files changed

+131
-14
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@ using SymbolicUtils.Rewriters
1212
using SymbolicUtils: similarterm, istree
1313

1414
using ModelingToolkit
15-
using ModelingToolkit: ODESystem, var_from_nested_derivative, Differential,
16-
states, equations, vars, Symbolic, diff2term, value,
17-
operation, arguments, Sym, Term, simplify, solve_for,
18-
isdiffeq, isdifferential,
19-
get_structure, defaults, InvalidSystemException
15+
using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative,
16+
Differential, states, equations, vars, Symbolic,
17+
diff2term, value, operation, arguments, Sym, Term,
18+
simplify, solve_for, isdiffeq, isdifferential,
19+
get_structure, defaults, InvalidSystemException,
20+
ExtraEquationsSystemException,
21+
ExtraVariablesSystemException
2022

2123
using ModelingToolkit.BipartiteGraphs
2224
using ModelingToolkit.SystemStructures

src/structural_transformation/utils.jl

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,61 @@ function matching(g::BipartiteGraph, varwhitelist=nothing, eqwhitelist=nothing)
4747
return assign
4848
end
4949

50+
function error_reporting(sys, bad_idxs, n_highest_vars, iseqs)
51+
io = IOBuffer()
52+
if iseqs
53+
error_title = "More equations than variables, here are the potential extra equation(s):\n"
54+
out_arr = equations(sys)[bad_idxs]
55+
else
56+
error_title = "More variables than equations, here are the potential extra variable(s):\n"
57+
out_arr = structure(sys).fullvars[bad_idxs]
58+
end
59+
60+
Base.print_array(io, out_arr)
61+
msg = String(take!(io))
62+
neqs = length(equations(sys))
63+
if iseqs
64+
throw(ExtraEquationsSystemException(
65+
"The system is unbalanced. "
66+
* "There are $n_highest_vars highest order derivative variables "
67+
* "and $neqs equations.\n"
68+
* error_title
69+
* msg
70+
))
71+
else
72+
throw(ExtraVariablesSystemException(
73+
"The system is unbalanced. "
74+
* "There are $n_highest_vars highest order derivative variables "
75+
* "and $neqs equations.\n"
76+
* error_title
77+
* msg
78+
))
79+
end
80+
end
81+
5082
###
5183
### Structural check
5284
###
53-
function check_consistency(s::SystemStructure)
85+
function check_consistency(sys::AbstractSystem)
86+
s = structure(sys)
5487
@unpack varmask, graph, varassoc, fullvars = s
5588
n_highest_vars = count(varmask)
5689
neqs = nsrcs(graph)
5790
is_balanced = n_highest_vars == neqs
5891

59-
(neqs > 0 && !is_balanced) && throw(InvalidSystemException(
60-
"The system is unbalanced. "
61-
* "There are $n_highest_vars highest order derivative variables "
62-
* "and $neqs equations."
63-
))
92+
if neqs > 0 && !is_balanced
93+
varwhitelist = varassoc .== 0
94+
assign = matching(graph, varwhitelist) # not assigned
95+
# Just use `error_reporting` to do conditional
96+
iseqs = n_highest_vars < neqs
97+
if iseqs
98+
inv_assign = inverse_mapping(assign) # extra equations
99+
bad_idxs = findall(iszero, @view inv_assign[1:nsrcs(graph)])
100+
else
101+
bad_idxs = findall(isequal(UNASSIGNED), assign)
102+
end
103+
error_reporting(sys, bad_idxs, n_highest_vars, iseqs)
104+
end
64105

65106
# This is defined to check if Pantelides algorithm terminates. For more
66107
# details, check the equation (15) of the original paper.

src/systems/abstractsystem.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,7 @@ topological sort of the observed equations.
667667
"""
668668
function structural_simplify(sys::AbstractSystem)
669669
sys = initialize_system_structure(alias_elimination(sys))
670-
check_consistency(structure(sys))
670+
check_consistency(sys)
671671
if sys isa ODESystem
672672
sys = dae_index_lowering(sys)
673673
end
@@ -688,6 +688,16 @@ struct InvalidSystemException <: Exception
688688
end
689689
Base.showerror(io::IO, e::InvalidSystemException) = print(io, "InvalidSystemException: ", e.msg)
690690

691+
struct ExtraVariablesSystemException <: Exception
692+
msg::String
693+
end
694+
Base.showerror(io::IO, e::ExtraVariablesSystemException) = print(io, "ExtraVariablesSystemException: ", e.msg)
695+
696+
struct ExtraEquationsSystemException <: Exception
697+
msg::String
698+
end
699+
Base.showerror(io::IO, e::ExtraEquationsSystemException) = print(io, "ExtraEquationsSystemException: ", e.msg)
700+
691701
AbstractTrees.children(sys::ModelingToolkit.AbstractSystem) = ModelingToolkit.get_systems(sys)
692702
AbstractTrees.printnode(io::IO, sys::ModelingToolkit.AbstractSystem) = print(io, nameof(sys))
693703
AbstractTrees.nodetype(::ModelingToolkit.AbstractSystem) = ModelingToolkit.AbstractSystem

src/systems/systemstructure.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,10 @@ function initialize_system_structure(sys)
180180
for algvar in algvars
181181
# it could be that a variable appeared in the states, but never appeared
182182
# in the equations.
183-
algvaridx = var2idx[algvar]
183+
algvaridx = get(var2idx, algvar, 0)
184+
algvaridx == 0 && throw(InvalidSystemException("The system is missing "
185+
* "an equation for $algvar."
186+
))
184187
vartype[algvaridx] = ALGEBRAIC_VARIABLE
185188
end
186189

test/error_handling.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
using Test
2+
using ModelingToolkit
3+
import ModelingToolkit: ExtraVariablesSystemException, ExtraEquationsSystemException
4+
5+
include("../examples/electrical_components.jl")
6+
7+
function UnderdefinedConstantVoltage(;name, V = 1.0)
8+
val = V
9+
@named p = Pin()
10+
@named n = Pin()
11+
@parameters V
12+
eqs = [
13+
V ~ p.v - n.v
14+
# Remove equation
15+
# 0 ~ p.i + n.i
16+
]
17+
ODESystem(eqs, t, [], [V], systems=[p, n], defaults=Dict(V => val), name=name)
18+
end
19+
20+
function OverdefinedConstantVoltage(;name, V = 1.0, I = 1.0)
21+
val = V
22+
val2 = I
23+
@named p = Pin()
24+
@named n = Pin()
25+
@parameters V I
26+
eqs = [
27+
V ~ p.v - n.v
28+
# Overdefine p.i and n.i
29+
n.i ~ I
30+
p.i ~ I
31+
]
32+
ODESystem(eqs, t, [], [V], systems=[p, n], defaults=Dict(V => val, I => val2), name=name)
33+
end
34+
35+
R = 1.0
36+
C = 1.0
37+
V = 1.0
38+
@named resistor = Resistor(R=R)
39+
@named capacitor = Capacitor(C=C)
40+
@named source = UnderdefinedConstantVoltage(V=V)
41+
42+
rc_eqs = [
43+
connect(source.p, resistor.p)
44+
connect(resistor.n, capacitor.p)
45+
connect(capacitor.n, source.n)
46+
]
47+
48+
@named rc_model = ODESystem(rc_eqs, t, systems=[resistor, capacitor, source])
49+
@test_throws ModelingToolkit.ExtraVariablesSystemException structural_simplify(rc_model)
50+
51+
52+
@named source2 = OverdefinedConstantVoltage(V=V, I=V/R)
53+
rc_eqs2 = [
54+
connect(source2.p, resistor.p)
55+
connect(resistor.n, capacitor.p)
56+
connect(capacitor.n, source2.n)
57+
]
58+
59+
@named rc_model2 = ODESystem(rc_eqs2, t, systems=[resistor, capacitor, source2])
60+
@test_throws ModelingToolkit.ExtraEquationsSystemException structural_simplify(rc_model2)

test/reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ eqs = [
233233
]
234234

235235
@named sys = ODESystem(eqs, t, [E, C, S, P], [k₁, k₂, k₋₁, E₀])
236-
@test_throws ModelingToolkit.InvalidSystemException structural_simplify(sys)
236+
@test_throws ModelingToolkit.ExtraEquationsSystemException structural_simplify(sys)
237237

238238
# Example 5 from Pantelides' original paper
239239
@parameters t

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,4 @@ println("Last test requires gcc available in the path!")
3838
@testset "Serialization" begin include("serialization.jl") end
3939
@safetestset "print_tree" begin include("print_tree.jl") end
4040
@safetestset "connectors" begin include("connectors.jl") end
41+
@safetestset "error_handling" begin include("error_handling.jl") end

0 commit comments

Comments
 (0)