Skip to content

Commit da56c3f

Browse files
committed
Add working tests
1 parent ce47f53 commit da56c3f

File tree

5 files changed

+143
-13
lines changed

5 files changed

+143
-13
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@ using SymbolicUtils.Rewriters
1212
using SymbolicUtils: similarterm, istree
1313

1414
using ModelingToolkit
15-
using ModelingToolkit: ODESystem, var_from_nested_derivative, Differential,
16-
states, equations, vars, Symbolic, diff2term, value,
17-
operation, arguments, Sym, Term, simplify, solve_for,
18-
isdiffeq, isdifferential,
19-
get_structure, defaults, InvalidSystemException
15+
using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative,
16+
Differential, states, equations, vars, Symbolic,
17+
diff2term, value, operation, arguments, Sym, Term,
18+
simplify, solve_for, isdiffeq, isdifferential,
19+
get_structure, defaults, InvalidSystemException,
20+
ExtraEquationsSystemException,
21+
ExtraVariablesSystemException
2022

2123
using ModelingToolkit.BipartiteGraphs
2224
using ModelingToolkit.SystemStructures

src/structural_transformation/utils.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,26 @@ function error_reporting(sys, bad_idxs, n_highest_vars, iseqs)
5757
out_arr = structure(sys).fullvars[bad_idxs]
5858
end
5959

60-
msg = String(take!(Base.print_array(io, out_arr)))
60+
Base.print_array(io, out_arr)
61+
msg = String(take!(io))
6162
neqs = length(equations(sys))
62-
throw(InvalidSystemException(
63-
"The system is unbalanced. "
64-
* "There are $n_highest_vars highest order derivative variables "
65-
* "and $neqs equations.\n"
66-
* msg
67-
))
63+
if iseqs
64+
throw(ExtraEquationsSystemException(
65+
"The system is unbalanced. "
66+
* "There are $n_highest_vars highest order derivative variables "
67+
* "and $neqs equations.\n"
68+
* error_title
69+
* msg
70+
))
71+
else
72+
throw(ExtraVariablesSystemException(
73+
"The system is unbalanced. "
74+
* "There are $n_highest_vars highest order derivative variables "
75+
* "and $neqs equations.\n"
76+
* error_title
77+
* msg
78+
))
79+
end
6880
end
6981

7082
###
@@ -81,7 +93,7 @@ function check_consistency(sys::AbstractSystem)
8193
varwhitelist = varassoc .== 0
8294
assign = matching(graph, varwhitelist) # not assigned
8395
# Just use `error_reporting` to do conditional
84-
iseqs = n_highest_vars > neqs
96+
iseqs = n_highest_vars < neqs
8597

8698
if iseqs
8799
bad_idxs = findall(isequal(UNASSIGNED), assign)

src/systems/abstractsystem.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,16 @@ struct InvalidSystemException <: Exception
665665
end
666666
Base.showerror(io::IO, e::InvalidSystemException) = print(io, "InvalidSystemException: ", e.msg)
667667

668+
struct ExtraVariablesSystemException <: Exception
669+
msg::String
670+
end
671+
Base.showerror(io::IO, e::ExtraVariablesSystemException) = print(io, "ExtraVariablesSystemException: ", e.msg)
672+
673+
struct ExtraEquationsSystemException <: Exception
674+
msg::String
675+
end
676+
Base.showerror(io::IO, e::ExtraEquationsSystemException) = print(io, "ExtraEquationsSystemException: ", e.msg)
677+
668678
AbstractTrees.children(sys::ModelingToolkit.AbstractSystem) = ModelingToolkit.get_systems(sys)
669679
AbstractTrees.printnode(io::IO, sys::ModelingToolkit.AbstractSystem) = print(io, nameof(sys))
670680
AbstractTrees.nodetype(::ModelingToolkit.AbstractSystem) = ModelingToolkit.AbstractSystem

test/error_handling.jl

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
using Test
2+
using ModelingToolkit#, OrdinaryDiffEq
3+
import ModelingToolkit: ExtraVariablesSystemException, ExtraEquationsSystemException
4+
5+
@parameters t
6+
function Pin(;name)
7+
@variables v(t) i(t)
8+
ODESystem(Equation[], t, [v, i], [], name=name, defaults=[v=>1.0, i=>1.0])
9+
end
10+
11+
function UnderdefinedConstantVoltage(;name, V = 1.0)
12+
val = V
13+
@named p = Pin()
14+
@named n = Pin()
15+
@parameters V
16+
eqs = [
17+
V ~ p.v - n.v
18+
#0 ~ p.i + n.i
19+
]
20+
ODESystem(eqs, t, [], [V], systems=[p, n], defaults=Dict(V => val), name=name)
21+
end
22+
23+
function OverdefinedConstantVoltage(;name, V = 1.0, I = 1.0)
24+
val = V
25+
val2 = I
26+
@named p = Pin()
27+
@named n = Pin()
28+
@parameters V I
29+
eqs = [
30+
V ~ p.v - n.v
31+
n.i ~ I
32+
p.i ~ I
33+
]
34+
ODESystem(eqs, t, [], [V], systems=[p, n], defaults=Dict(V => val, I => val2), name=name)
35+
end
36+
37+
function Resistor(;name, R = 1.0)
38+
val = R
39+
@named p = Pin()
40+
@named n = Pin()
41+
@variables v(t)
42+
@parameters R
43+
eqs = [
44+
v ~ p.v - n.v
45+
0 ~ p.i + n.i
46+
v ~ p.i * R
47+
]
48+
ODESystem(eqs, t, [v], [R], systems=[p, n], defaults=Dict(R => val), name=name)
49+
end
50+
51+
function Capacitor(;name, C = 1.0)
52+
val = C
53+
@named p = Pin()
54+
@named n = Pin()
55+
@variables v(t)
56+
@parameters C
57+
D = Differential(t)
58+
eqs = [
59+
v ~ p.v - n.v
60+
0 ~ p.i + n.i
61+
D(v) ~ p.i / C
62+
]
63+
ODESystem(eqs, t, [v], [C], systems=[p, n], defaults=Dict(C => val), name=name)
64+
end
65+
66+
function ModelingToolkit.connect(ps...)
67+
eqs = [
68+
0 ~ sum(p->p.i, ps) # KCL
69+
]
70+
# KVL
71+
for i in 1:length(ps)-1
72+
push!(eqs, ps[i].v ~ ps[i+1].v)
73+
end
74+
75+
return eqs
76+
end
77+
78+
R = 1.0
79+
C = 1.0
80+
V = 1.0
81+
@named resistor = Resistor(R=R)
82+
@named capacitor = Capacitor(C=C)
83+
@named source = UnderdefinedConstantVoltage(V=V)
84+
85+
rc_eqs = [
86+
connect(source.p, resistor.p)
87+
connect(resistor.n, capacitor.p)
88+
connect(capacitor.n, source.n)
89+
]
90+
91+
@named rc_model = ODESystem(rc_eqs, t, systems=[resistor, capacitor, source])
92+
93+
94+
@test_throws ModelingToolkit.ExtraVariablesSystemException structural_simplify(rc_model)
95+
96+
97+
@named source2 = OverdefinedConstantVoltage(V=V, I=V/R)
98+
rc_eqs2 = [
99+
connect(source2.p, resistor.p)
100+
connect(resistor.n, capacitor.p)
101+
connect(capacitor.n, source2.n)
102+
]
103+
104+
@named rc_model2 = ODESystem(rc_eqs2, t, systems=[resistor, capacitor, source2])
105+
@test_throws ModelingToolkit.ExtraEquationsSystemException structural_simplify(rc_model2)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,4 @@ println("Last test requires gcc available in the path!")
3838
@testset "Serialization" begin include("serialization.jl") end
3939
@safetestset "print_tree" begin include("print_tree.jl") end
4040
@safetestset "connectors" begin include("connectors.jl") end
41+
@safetestset "error_handling" begin include("error_handling.jl") end

0 commit comments

Comments
 (0)