Skip to content

Commit b1d736d

Browse files
committed
Move Unitful checks to a separate module
1 parent 90789c0 commit b1d736d

File tree

10 files changed

+73
-59
lines changed

10 files changed

+73
-59
lines changed

src/systems/diffeqs/sdesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ struct SDESystem <: AbstractODESystem
134134
check_equations(equations(cevents), iv)
135135
end
136136
if checks == true || (checks & CheckUnits) > 0
137-
all_dimensionless([dvs; ps; iv]) || check_units(deqs, neqs)
137+
u = __get_unit_type(dvs, ps, iv)
138+
check_units(u, deqs, neqs)
138139
end
139140
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
140141
ctrl_jac,

src/systems/jumps/jumpsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
112112
check_parameters(ps, iv)
113113
end
114114
if checks == true || (checks & CheckUnits) > 0
115-
all_dimensionless([states; ps; iv]) || check_units(ap, iv)
115+
u = __get_unit_type(states, ps, iv)
116+
check_units(u, ap, iv)
116117
end
117118
new{U}(tag, ap, iv, states, ps, var_to_name, observed, name, systems, defaults,
118119
connector_type, devents, metadata, gui_metadata, complete)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem
8888
tearing_state = nothing, substitutions = nothing,
8989
complete = false, parent = nothing; checks::Union{Bool, Int} = true)
9090
if checks == true || (checks & CheckUnits) > 0
91-
all_dimensionless([states; ps]) || check_units(eqs)
91+
u = __get_unit_type(states, ps)
92+
check_units(u, eqs)
9293
end
9394
new(tag, eqs, states, ps, var_to_name, observed, jac, name, systems, defaults,
9495
connector_type, metadata, gui_metadata, tearing_state, substitutions, complete,

src/systems/optimization/constraints_system.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ struct ConstraintsSystem <: AbstractTimeIndependentSystem
7777
tearing_state = nothing, substitutions = nothing;
7878
checks::Union{Bool, Int} = true)
7979
if checks == true || (checks & CheckUnits) > 0
80-
all_dimensionless([states; ps]) || check_units(constraints)
80+
u = __get_unit_type(states, ps)
81+
check_units(u, constraints)
8182
end
8283
new(tag, constraints, states, ps, var_to_name, observed, jac, name, systems,
8384
defaults,

src/systems/optimization/optimizationsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ struct OptimizationSystem <: AbstractOptimizationSystem
6868
if checks == true || (checks & CheckUnits) > 0
6969
unwrap(op) isa Symbolic && check_units(op)
7070
check_units(observed)
71-
all_dimensionless([states; ps]) || check_units(constraints)
71+
u = __get_unit_type(states, ps)
72+
check_units(u, constraints)
7273
end
7374
new(tag, op, states, ps, var_to_name, observed,
7475
constraints, name, systems, defaults, metadata, gui_metadata, complete,

src/systems/pde/pdesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ struct PDESystem <: ModelingToolkit.AbstractMultivariateSystem
9898
checks::Union{Bool, Int} = true,
9999
name)
100100
if checks == true || (checks & CheckUnits) > 0
101-
all_dimensionless([dvs; ivs; ps]) || check_units(eqs)
101+
u = __get_unit_type(dvs, ivs, ps)
102+
check_units(u, deqs)
102103
end
103104

104105
eqs = eqs isa Vector ? eqs : [eqs]

src/systems/unit_check.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import DynamicQuantities
1+
import DynamicQuantities, Unitful
22
const DQ = DynamicQuantities
33

44
struct ValidationError <: Exception
@@ -8,14 +8,26 @@ end
88
check_units(::Nothing, _...) = true
99

1010
__get_literal_unit(x) = getmetadata(x, VariableUnit, nothing)
11+
function __get_scalar_unit_type(v)
12+
u = __get_literal_unit(v)
13+
if u isa DQ.AbstractQuantity
14+
return Val(:DynamicQuantities)
15+
elseif u isa Unitful.Unitlike
16+
return Val(:Unitful)
17+
end
18+
return nothing
19+
end
1120
function __get_unit_type(vs′...)
12-
vs = Iterators.flatten(vs′)
13-
for v in vs
14-
u = __get_literal_unit(v)
15-
if u isa DQ.AbstractQuantity
16-
return Val(:Unitful)
21+
for vs in vs′
22+
if vs isa AbstractVector
23+
for v in vs
24+
u = __get_scalar_unit_type(v)
25+
u === nothing || return u
26+
end
1727
else
18-
return Val(:DynamicQuantities)
28+
v = vs
29+
u = __get_scalar_unit_type(v)
30+
u === nothing || return u
1931
end
2032
end
2133
return nothing
@@ -38,5 +50,3 @@ function screen_units(result)
3850
end
3951
end
4052
end
41-
42-

src/systems/validation.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
module UnitfulUnitCheck
22

3-
using .ModelingToolkit, Symbolics, SciMLBase
4-
using .ModelingToolkit: ValidationError
3+
using ..ModelingToolkit, Symbolics, SciMLBase, Unitful, IfElse, RecursiveArrayTools
4+
using ..ModelingToolkit: ValidationError,
5+
ModelingToolkit, Connection, instream, JumpType, VariableUnit, get_systems
6+
using Symbolics: Symbolic, value, issym, isadd, ismul, ispow
57
const MT = ModelingToolkit
68

79
Base.:*(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y
@@ -62,11 +64,7 @@ get_literal_unit(x) = screen_unit(getmetadata(x, VariableUnit, unitless))
6264
function get_unit(op, args) # Fallback
6365
result = op(1 .* get_unit.(args)...)
6466
try
65-
if result isa DQ.AbstractQuantity
66-
oneunit(result)
67-
else
68-
unit(result)
69-
end
67+
unit(result)
7068
catch
7169
throw(ValidationError("Unable to get unit for operation $op with arguments $args."))
7270
end
@@ -283,6 +281,5 @@ function MT.check_units(::Val{:Unitful}, eqs...)
283281
validate(eqs...) ||
284282
throw(ValidationError("Some equations had invalid units. See warnings for details."))
285283
end
286-
all_dimensionless(states) = all(x -> safe_get_unit(x, "") in (unitless, nothing), states)
287284

288285
end # module

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ using SafeTestsets, Test
1313
@safetestset "Clock Test" include("clock.jl")
1414
@safetestset "DiscreteSystem Test" include("discretesystem.jl")
1515
@safetestset "ODESystem Test" include("odesystem.jl")
16-
@safetestset "Dynamic Quantities Test" begin
17-
using DynamicQuantities
18-
include("units.jl")
19-
end
16+
#@safetestset "Dynamic Quantities Test" begin
17+
# using DynamicQuantities
18+
# include("units.jl")
19+
#end
2020
@safetestset "Unitful Quantities Test" begin
2121
using Unitful
2222
include("units.jl")

test/units.jl

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,52 @@
11
using ModelingToolkit, OrdinaryDiffEq, JumpProcesses, IfElse
22
using Test
33
MT = ModelingToolkit
4+
UMT = ModelingToolkit.UnitfulUnitCheck
45
@parameters τ [unit = u"ms"] γ
56
@variables t [unit = u"ms"] E(t) [unit = u"kJ"] P(t) [unit = u"MW"]
67
D = Differential(t)
78

89
#This is how equivalent works:
9-
@test MT.equivalent(u"MW", u"kJ/ms")
10-
@test !MT.equivalent(u"m", u"cm")
11-
@test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E / τ)^γ))
10+
@test UMT.equivalent(u"MW", u"kJ/ms")
11+
@test !UMT.equivalent(u"m", u"cm")
12+
@test UMT.equivalent(UMT.get_unit(P^γ), UMT.get_unit((E / τ)^γ))
1213

1314
# Basic access
14-
@test MT.get_unit(t) == u"ms"
15-
@test MT.get_unit(E) == u"kJ"
16-
@test MT.get_unit(τ) == u"ms"
17-
@test MT.get_unit(γ) == MT.unitless
18-
@test MT.get_unit(0.5) == MT.unitless
19-
@test MT.get_unit(MT.SciMLBase.NullParameters()) == MT.unitless
15+
@test UMT.get_unit(t) == u"ms"
16+
@test UMT.get_unit(E) == u"kJ"
17+
@test UMT.get_unit(τ) == u"ms"
18+
@test UMT.get_unit(γ) == UMT.unitless
19+
@test UMT.get_unit(0.5) == UMT.unitless
20+
@test UMT.get_unit(UMT.SciMLBase.NullParameters()) == UMT.unitless
2021

2122
# Prohibited unit types
2223
@parameters β [unit = u"°"] α [unit = u"°C"] γ [unit = 1u"s"]
23-
@test_throws MT.ValidationError MT.get_unit(β)
24-
@test_throws MT.ValidationError MT.get_unit(α)
25-
@test_throws MT.ValidationError MT.get_unit(γ)
24+
@test_throws UMT.ValidationError UMT.get_unit(β)
25+
@test_throws UMT.ValidationError UMT.get_unit(α)
26+
@test_throws UMT.ValidationError UMT.get_unit(γ)
2627

2728
# Non-trivial equivalence & operators
28-
@test MT.get_unit^-1) == u"ms^-1"
29-
@test MT.equivalent(MT.get_unit(D(E)), u"MW")
30-
@test MT.equivalent(MT.get_unit(E / τ), u"MW")
31-
@test MT.get_unit(2 * P) == u"MW"
32-
@test MT.get_unit(t / τ) == MT.unitless
33-
@test MT.equivalent(MT.get_unit(P - E / τ), u"MW")
34-
@test MT.equivalent(MT.get_unit(D(D(E))), u"MW/ms")
35-
@test MT.get_unit(IfElse.ifelse(t > t, P, E / τ)) == u"MW"
36-
@test MT.get_unit(1.0^(t / τ)) == MT.unitless
37-
@test MT.get_unit(exp(t / τ)) == MT.unitless
38-
@test MT.get_unit(sin(t / τ)) == MT.unitless
39-
@test MT.get_unit(sin(1u"rad")) == MT.unitless
40-
@test MT.get_unit(t^2) == u"ms^2"
29+
@test UMT.get_unit^-1) == u"ms^-1"
30+
@test UMT.equivalent(UMT.get_unit(D(E)), u"MW")
31+
@test UMT.equivalent(UMT.get_unit(E / τ), u"MW")
32+
@test UMT.get_unit(2 * P) == u"MW"
33+
@test UMT.get_unit(t / τ) == UMT.unitless
34+
@test UMT.equivalent(UMT.get_unit(P - E / τ), u"MW")
35+
@test UMT.equivalent(UMT.get_unit(D(D(E))), u"MW/ms")
36+
@test UMT.get_unit(IfElse.ifelse(t > t, P, E / τ)) == u"MW"
37+
@test UMT.get_unit(1.0^(t / τ)) == UMT.unitless
38+
@test UMT.get_unit(exp(t / τ)) == UMT.unitless
39+
@test UMT.get_unit(sin(t / τ)) == UMT.unitless
40+
@test UMT.get_unit(sin(1u"rad")) == UMT.unitless
41+
@test UMT.get_unit(t^2) == u"ms^2"
4142

4243
eqs = [D(E) ~ P - E / τ
4344
0 ~ P]
44-
@test MT.validate(eqs)
45+
@test UMT.validate(eqs)
4546
@named sys = ODESystem(eqs)
4647

47-
@test !MT.validate(D(D(E)) ~ P)
48-
@test !MT.validate(0 ~ P + E * τ)
48+
@test !UMT.validate(D(D(E)) ~ P)
49+
@test !UMT.validate(0 ~ P + E * τ)
4950

5051
# Disabling unit validation/checks selectively
5152
@test_throws MT.ArgumentError ODESystem(eqs, t, [E, P, t], [τ], name = :sys)
@@ -86,9 +87,9 @@ end
8687
good_eqs = [connect(p1, p2)]
8788
bad_eqs = [connect(p1, p2, op)]
8889
bad_length_eqs = [connect(op, lp)]
89-
@test MT.validate(good_eqs)
90-
@test !MT.validate(bad_eqs)
91-
@test !MT.validate(bad_length_eqs)
90+
@test UMT.validate(good_eqs)
91+
@test !UMT.validate(bad_eqs)
92+
@test !UMT.validate(bad_length_eqs)
9293
@named sys = ODESystem(good_eqs, t, [], [])
9394
@test_throws MT.ValidationError ODESystem(bad_eqs, t, [], []; name = :sys)
9495

@@ -136,7 +137,7 @@ noiseeqs = [0.1u"MW" 0.1u"MW"
136137
# Invalid noise matrix
137138
noiseeqs = [0.1u"MW" 0.1u"MW"
138139
0.1u"MW" 0.1u"s"]
139-
@test !MT.validate(eqs, noiseeqs)
140+
@test !UMT.validate(eqs, noiseeqs)
140141

141142
# Non-trivial simplifications
142143
@variables t [unit = u"s"] V(t) [unit = u"m"^3] L(t) [unit = u"m"]

0 commit comments

Comments
 (0)