Skip to content

Commit 978d022

Browse files
committed
Add unit validation for Connections
1 parent bbe5081 commit 978d022

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

src/systems/validation.jl

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,36 @@ function _validate(terms::Vector, labels::Vector{String}; info::String = "")
170170
valid
171171
end
172172

173+
function _validate(conn::Connection; info::String="")
174+
valid = true
175+
syss = get_systems(conn)
176+
sys = first(syss)
177+
st = states(sys)
178+
for s in syss[2:end]
179+
sst = states(s)
180+
if length(st) != length(sst)
181+
valid = false
182+
@warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) have $(length(st)) and $(length(sst)) states, cannor connect.")
183+
continue
184+
end
185+
for (i, x) in enumerate(st)
186+
j = findfirst(isequal(x), sst)
187+
if j == nothing
188+
valid = false
189+
@warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) do not have the same states.")
190+
else
191+
aunit = safe_get_unit(x, info * string(nameof(sys)) * "#$i")
192+
bunit = safe_get_unit(sst[j], info * string(nameof(s)) * "#$j")
193+
if !equivalent(aunit, bunit)
194+
valid = false
195+
@warn("$info: connected system states $x and $(sst[j]) have mismatched units.")
196+
end
197+
end
198+
end
199+
end
200+
valid
201+
end
202+
173203
function validate(jump::Union{ModelingToolkit.VariableRateJump,
174204
ModelingToolkit.ConstantRateJump}, t::Symbolic;
175205
info::String = "")
@@ -195,7 +225,11 @@ function validate(jumps::ArrayPartition{<:Union{Any, Vector{<:JumpType}}}, t::Sy
195225
end
196226

197227
function validate(eq::ModelingToolkit.Equation; info::String = "")
198-
_validate([eq.lhs, eq.rhs], ["left", "right"]; info)
228+
if typeof(eq.lhs) == Connection
229+
_validate(eq.rhs; info)
230+
else
231+
_validate([eq.lhs, eq.rhs], ["left", "right"]; info)
232+
end
199233
end
200234
function validate(eq::ModelingToolkit.Equation,
201235
term::Union{Symbolic, Unitful.Quantity, Num}; info::String = "")

test/units.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,42 @@ ODESystem(eqs, name = :sys, checks = false)
6262
@test_throws MT.ValidationError ODESystem(eqs, t, [E, P, t], [τ], name = :sys,
6363
checks = MT.CheckUnits)
6464

65+
# connection validation
66+
@connector function Pin(;name)
67+
sts = @variables(
68+
v(t)=1.0, [unit=u"V"],
69+
i(t)=1.0, [unit=u"A", connect = Flow]
70+
)
71+
ODESystem(Equation[], t, sts, []; name=name)
72+
end
73+
@connector function OtherPin(;name)
74+
sts = @variables(
75+
v(t)=1.0, [unit=u"mV"],
76+
i(t)=1.0, [unit=u"mA", connect = Flow]
77+
)
78+
ODESystem(Equation[], t, sts, []; name=name)
79+
end
80+
@connector function LongPin(;name)
81+
sts = @variables(
82+
v(t)=1.0, [unit=u"V"],
83+
i(t)=1.0, [unit=u"A", connect = Flow],
84+
x(t)=1.0, [unit=NoUnits]
85+
)
86+
ODESystem(Equation[], t, sts, []; name=name)
87+
end
88+
@named p1 = Pin()
89+
@named p2 = Pin()
90+
@named op = OtherPin()
91+
@named lp = LongPin()
92+
good_eqs = [connect(p1, p2)]
93+
bad_eqs = [connect(p1, p2, op)]
94+
bad_length_eqs = [connect(op, lp)]
95+
@test MT.validate(good_eqs)
96+
@test !MT.validate(bad_eqs)
97+
@test !MT.validate(bad_length_eqs)
98+
@named sys = ODESystem(good_eqs, t, [], [])
99+
@test_throws MT.ValidationError ODESystem(bad_eqs, t, [], []; name = :sys)
100+
65101
# Array variables
66102
@variables t [unit = u"s"] x(t)[1:3] [unit = u"m"]
67103
@parameters v[1:3]=[1, 2, 3] [unit = u"m/s"]

0 commit comments

Comments
 (0)