Skip to content

Commit 695d62b

Browse files
committed
Adding validation of Unitful quantities. Doesn't handle symbolic arrays yet.
1 parent c7b00d6 commit 695d62b

File tree

5 files changed

+114
-32
lines changed

5 files changed

+114
-32
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ struct ODESystem <: AbstractODESystem
8888
check_variables(dvs,iv)
8989
check_parameters(ps,iv)
9090
check_equations(deqs,iv)
91+
validate(deqs)
9192
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
9293
end
9394
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ struct SDESystem <: AbstractODESystem
9090
check_variables(dvs,iv)
9191
check_parameters(ps,iv)
9292
check_equations(deqs,iv)
93+
validate(deqs)
9394
new(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
9495
end
9596
end

src/systems/diffeqs/validation.jl

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,72 @@
11
Base.:*(x::Union{Num,Symbolic},y::Unitful.AbstractQuantity) = x * y
22

3-
instantiate(x::Sym{Real}) = 1.0
4-
instantiate(x::Symbolic) = oneunit(1*ModelingToolkit.vartype(x))
5-
function instantiate(x::Num)
6-
x = value(x)
7-
if operation(x) isa Sym
8-
return instantiate(operation(x))
9-
elseif operation(x) isa Differential
10-
instantiate(arguments(x)[1])/instantiate(arguments(x)[1].args[1])
3+
4+
function vartype(x::Symbolic)
5+
if !(x.metadata isa Nothing)
6+
return haskey(x.metadata,VariableUnit) ? x.metadata[VariableUnit] : 1.0
7+
end
8+
1.0
9+
end
10+
vartype(x::Num) = vartype(value(x))
11+
12+
instantiate(x) = 1.0
13+
instantiate(x::Num) = instantiate(value(x))
14+
function instantiate(x::Symbolic)
15+
vx = value(x)
16+
if vx isa Sym || operation(vx) isa Sym
17+
return oneunit(1 * ModelingToolkit.vartype(x))
18+
elseif operation(vx) isa Differential
19+
return instantiate(arguments(vx)[1]) / instantiate(arguments(arguments(vx)[1])[1])
20+
elseif vx isa Pow
21+
pargs = arguments(vx)
22+
base,expon = instantiate.(pargs)
23+
uconvert(NoUnits, expon) # This acts as an assertion
24+
return base == 1.0 ? 1.0 : operation(vx)(base, pargs[2])
25+
elseif vx isa Add # Cannot simply add the units b/c they may differ in magnitude (eg, kg vs g)
26+
terms = instantiate.(arguments(vx))
27+
firstunit = unit(terms[1])
28+
@assert all(map(x -> ustrip(firstunit, x) == 1, terms[2:end]))
29+
return 1.0 * firstunit
1130
else
12-
operation(x)(instantiate.(arguments(x))...)
31+
return oneunit(operation(vx)(instantiate.(arguments(vx))...))
1332
end
1433
end
1534

16-
function validate(eq::ModelingToolkit.Equation)
35+
function validate(eq::ModelingToolkit.Equation; eqnum = 1)
36+
lhs = rhs = nothing
1737
try
18-
return typeof(instantiate(eq.lhs)) == typeof(instantiate(eq.rhs))
19-
catch
20-
return false
38+
lhs = instantiate(eq.lhs)
39+
catch err
40+
if err isa Unitful.DimensionError
41+
@warn("In left-hand side of eq. #$eqnum: $(eq.lhs), $(err.x) and $(err.y) are not dimensionally compatible.")
42+
elseif err isa MethodError
43+
@warn("In right-hand side of eq. #$eqnum: $(err.f) doesn't accept $(err.args).")
44+
else
45+
rethrow()
46+
end
2147
end
48+
try
49+
rhs = instantiate(eq.rhs)
50+
catch err
51+
if err isa Unitful.DimensionError
52+
@warn("In right-hand side of eq. #$eqnum: $(eq.rhs), $(err.x) and $(err.y) are not dimensionally compatible.")
53+
elseif err isa MethodError
54+
@warn("In right-hand side of eq. #$eqnum: $(err.f) doesn't accept $(err.args).")
55+
else
56+
rethrow()
57+
end
58+
end
59+
if (rhs !== nothing) && (lhs !== nothing)
60+
if !isequal(lhs, rhs)
61+
@warn("In eq. #$eqnum, left-side units ($lhs) and right-side units ($rhs) don't match.")
62+
end
63+
end
64+
(rhs !== nothing) && (lhs !== nothing) && isequal(lhs, rhs)
2265
end
2366

24-
function validate(sys::AbstractODESystem)
25-
all(validate.(equations(sys)))
67+
function validate(eqs::Vector{ModelingToolkit.Equation})
68+
correct = [validate(eqs[idx],eqnum=idx) for idx in 1:length(eqs)]
69+
all(correct) || throw(ArgumentError("Invalid equations, see warnings for details."))
2670
end
71+
72+
validate(sys::AbstractODESystem) = validate(equations(sys))

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using SafeTestsets, Test
88
@safetestset "System Linearity Test" begin include("linearity.jl") end
99
@safetestset "DiscreteSystem Test" begin include("discretesystem.jl") end
1010
@safetestset "ODESystem Test" begin include("odesystem.jl") end
11+
@safetestset "Unitful Quantities Test" begin include("units.jl") end
1112
@safetestset "LabelledArrays Test" begin include("labelledarrays.jl") end
1213
@safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end
1314
@safetestset "SteadyStateSystem Test" begin include("steadystatesystems.jl") end

test/units.jl

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,58 @@
11
using ModelingToolkit, Unitful
22
using Test
3-
4-
t = Variable{u"s"}(:t)()
5-
x = Variable{u"kg"}(:x)(t)
6-
y = Variable{u"kg"}(:y)(t)
3+
MT = ModelingToolkit
4+
@parameters τ [unit=u"ms"]
5+
@variables t [unit=u"ms"] E(t) [unit=u"kJ"] P(t) [unit=u"MW"]
76
D = Differential(t)
87

9-
eq1 = x ~ y*t
10-
eq2 = x*10u"s" ~ y*t
8+
@test MT.vartype(t) == u"ms"
9+
@test MT.vartype(E) == u"kJ"
10+
@test MT.vartype(τ) == u"ms"
11+
12+
eqs = [D(E) ~ P-E/τ ]
13+
sys = ODESystem(eqs)
14+
15+
@test MT.instantiate(eqs[1].lhs) == 1.0u"MW"
16+
@test MT.instantiate(eqs[1].rhs) == 1.0u"MW"
17+
@test MT.validate(eqs[1])
18+
@test MT.validate(sys)
19+
20+
@test MT.instantiate(0.5) == 1.0
21+
@test MT.instantiate(t) == 1.0u"ms"
22+
@test MT.instantiate(P) == 1.0u"MW"
23+
@test MT.instantiate(τ) == 1.0u"ms"
1124

12-
@test ModelingToolkit.instantiate(t) == 1u"s"
13-
@test ModelingToolkit.instantiate(x) == 1u"kg"
14-
@test ModelingToolkit.instantiate(y) == 1u"kg"
25+
@test MT.instantiate^-1) == 1/u"ms"
26+
@test MT.instantiate(D(E)) == 1.0u"MW"
27+
@test MT.instantiate(E/τ) == 1.0u"MW"
28+
@test MT.instantiate(2*P) == 1.0u"MW"
29+
@test MT.instantiate(t/τ) == 1.0
30+
@test MT.instantiate(P-E/τ)/1.0u"MW" == 1.0
1531

16-
@test !ModelingToolkit.validate(eq1)
17-
@test ModelingToolkit.validate(eq2)
32+
@test MT.instantiate(1.0^(t/τ)) == 1.0
33+
@test MT.instantiate(exp(t/τ)) == 1.0
34+
@test MT.instantiate(sin(t/τ)) == 1.0
35+
@test MT.instantiate(sin(1u"rad")) == 1.0
36+
@test MT.instantiate(t^2) == 1.0u"ms"^2
1837

19-
eqs = [
20-
D(x) ~ y/t
21-
D(y) ~ (x*y)/(t*10u"kg")
22-
]
38+
@test !MT.validate(E^1.5~ E^(t/τ))
39+
@test MT.validate(E^(t/τ)~ E^(t/τ))
2340

24-
sys = ODESystem(eqs,t,[x,y],[])
25-
@test ModelingToolkit.validate(sys)
41+
sys = ODESystem(eqs,t,[P,E],[τ])
42+
@test MT.validate(sys)
43+
44+
@test !MT.validate(D(D(E))~P)
45+
@test !MT.validate(0~P+E*τ)
46+
@test_logs (:warn,) MT.validate(0 ~ P + E*τ)
47+
@test_logs (:warn,) MT.validate(P + E*τ ~ 0)
48+
@test_logs (:warn,) MT.validate(P ~ 0)
49+
50+
@variables x y z u
51+
@parameters σ ρ β
52+
eqs = [0 ~ σ*(y-x)]
53+
@test MT.validate(eqs) #should cope with unit-free
54+
55+
@variables t x[1:3,1:3](t) #should cope with arrays
56+
D = Differential(t)
57+
eqs = D.(x) .~ x
58+
ODESystem(eqs)

0 commit comments

Comments
 (0)