Skip to content

Commit 084541f

Browse files
authored
Merge pull request #1929 from contradict/validate-connection
Add unit validation for Connections
2 parents 2eaf863 + 36adaed commit 084541f

File tree

2 files changed

+70
-5
lines changed

2 files changed

+70
-5
lines changed

src/systems/validation.jl

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ MT = ModelingToolkit
2929
```
3030
"""
3131
equivalent(x, y) = isequal(1 * x, 1 * y)
32-
unitless = Unitful.unit(1)
32+
const unitless = Unitful.unit(1)
3333

3434
#For dispatching get_unit
35-
Literal = Union{Sym, Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata}
36-
Conditional = Union{typeof(ifelse), typeof(IfElse.ifelse)}
37-
Comparison = Union{typeof.([==, !=, , <, <=, , >, >=, ])...}
35+
const Literal = Union{Sym, Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata}
36+
const Conditional = Union{typeof(ifelse), typeof(IfElse.ifelse)}
37+
const Comparison = Union{typeof.([==, !=, , <, <=, , >, >=, ])...}
3838

3939
"Find the unit of a symbolic item."
4040
get_unit(x::Real) = unitless
@@ -170,6 +170,37 @@ function _validate(terms::Vector, labels::Vector{String}; info::String = "")
170170
valid
171171
end
172172

173+
function _validate(conn::Connection; info::String = "")
174+
valid = true
175+
syss = get_systems(conn)
176+
sys = first(syss)
177+
st = states(sys)
178+
for i in 2:length(syss)
179+
s = syss[i]
180+
sst = states(s)
181+
if length(st) != length(sst)
182+
valid = false
183+
@warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) have $(length(st)) and $(length(sst)) states, cannor connect.")
184+
continue
185+
end
186+
for (i, x) in enumerate(st)
187+
j = findfirst(isequal(x), sst)
188+
if j == nothing
189+
valid = false
190+
@warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) do not have the same states.")
191+
else
192+
aunit = safe_get_unit(x, info * string(nameof(sys)) * "#$i")
193+
bunit = safe_get_unit(sst[j], info * string(nameof(s)) * "#$j")
194+
if !equivalent(aunit, bunit)
195+
valid = false
196+
@warn("$info: connected system states $x and $(sst[j]) have mismatched units.")
197+
end
198+
end
199+
end
200+
end
201+
valid
202+
end
203+
173204
function validate(jump::Union{ModelingToolkit.VariableRateJump,
174205
ModelingToolkit.ConstantRateJump}, t::Symbolic;
175206
info::String = "")
@@ -195,7 +226,11 @@ function validate(jumps::ArrayPartition{<:Union{Any, Vector{<:JumpType}}}, t::Sy
195226
end
196227

197228
function validate(eq::ModelingToolkit.Equation; info::String = "")
198-
_validate([eq.lhs, eq.rhs], ["left", "right"]; info)
229+
if typeof(eq.lhs) == Connection
230+
_validate(eq.rhs; info)
231+
else
232+
_validate([eq.lhs, eq.rhs], ["left", "right"]; info)
233+
end
199234
end
200235
function validate(eq::ModelingToolkit.Equation,
201236
term::Union{Symbolic, Unitful.Quantity, Num}; info::String = "")

test/units.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,36 @@ ODESystem(eqs, name = :sys, checks = false)
6262
@test_throws MT.ValidationError ODESystem(eqs, t, [E, P, t], [τ], name = :sys,
6363
checks = MT.CheckUnits)
6464

65+
# connection validation
66+
@connector function Pin(; name)
67+
sts = @variables(v(t)=1.0, [unit = u"V"],
68+
i(t)=1.0, [unit = u"A", connect = Flow])
69+
ODESystem(Equation[], t, sts, []; name = name)
70+
end
71+
@connector function OtherPin(; name)
72+
sts = @variables(v(t)=1.0, [unit = u"mV"],
73+
i(t)=1.0, [unit = u"mA", connect = Flow])
74+
ODESystem(Equation[], t, sts, []; name = name)
75+
end
76+
@connector function LongPin(; name)
77+
sts = @variables(v(t)=1.0, [unit = u"V"],
78+
i(t)=1.0, [unit = u"A", connect = Flow],
79+
x(t)=1.0, [unit = NoUnits])
80+
ODESystem(Equation[], t, sts, []; name = name)
81+
end
82+
@named p1 = Pin()
83+
@named p2 = Pin()
84+
@named op = OtherPin()
85+
@named lp = LongPin()
86+
good_eqs = [connect(p1, p2)]
87+
bad_eqs = [connect(p1, p2, op)]
88+
bad_length_eqs = [connect(op, lp)]
89+
@test MT.validate(good_eqs)
90+
@test !MT.validate(bad_eqs)
91+
@test !MT.validate(bad_length_eqs)
92+
@named sys = ODESystem(good_eqs, t, [], [])
93+
@test_throws MT.ValidationError ODESystem(bad_eqs, t, [], []; name = :sys)
94+
6595
# Array variables
6696
@variables t [unit = u"s"] x(t)[1:3] [unit = u"m"]
6797
@parameters v[1:3]=[1, 2, 3] [unit = u"m/s"]

0 commit comments

Comments
 (0)