Skip to content

Commit 4f51c0f

Browse files
Merge pull request #2621 from SciML/myb/nonSI
Add support for symbolic units by normalizing them to SI units
2 parents 3fe0d7e + ba781bc commit 4f51c0f

File tree

2 files changed

+46
-20
lines changed

2 files changed

+46
-20
lines changed

src/systems/unit_check.jl

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,11 @@ function screen_unit(result)
4848
if result isa DQ.AbstractQuantity
4949
d = DQ.dimension(result)
5050
if d isa DQ.Dimensions
51-
if result != oneunit(result)
52-
throw(ValidationError("$result uses non SI unit. Please use SI unit only."))
53-
end
5451
return result
5552
elseif d isa DQ.SymbolicDimensions
56-
throw(ValidationError("$result uses SymbolicDimensions, please use `u\"m\"` to instantiate SI unit only."))
53+
return DQ.uexpand(oneunit(result))
5754
else
58-
throw(ValidationError("$result doesn't use SI unit, please use `u\"m\"` to instantiate SI unit only."))
55+
throw(ValidationError("$result doesn't have a recognized unit"))
5956
end
6057
else
6158
throw(ValidationError("$result doesn't have any unit."))
@@ -69,7 +66,7 @@ get_literal_unit(x) = screen_unit(something(__get_literal_unit(x), unitless))
6966
Find the unit of a symbolic item.
7067
"""
7168
get_unit(x::Real) = unitless
72-
get_unit(x::DQ.AbstractQuantity) = screen_unit(oneunit(x))
69+
get_unit(x::DQ.AbstractQuantity) = screen_unit(x)
7370
get_unit(x::AbstractArray) = map(get_unit, x)
7471
get_unit(x::Num) = get_unit(unwrap(x))
7572
get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x)
@@ -81,12 +78,19 @@ get_unit(op::typeof(instream), args) = get_unit(args[1])
8178
function get_unit(op, args) # Fallback
8279
result = op(get_unit.(args)...)
8380
try
84-
oneunit(result)
81+
result
8582
catch
8683
throw(ValidationError("Unable to get unit for operation $op with arguments $args."))
8784
end
8885
end
8986

87+
function get_unit(::Union{typeof(+), typeof(-)}, args)
88+
u = get_unit(args[1])
89+
if all(i -> get_unit(args[i]) == u, 2:length(args))
90+
return u
91+
end
92+
end
93+
9094
function get_unit(op::Integral, args)
9195
unit = 1
9296
if op.domain.variables isa Vector
@@ -96,7 +100,7 @@ function get_unit(op::Integral, args)
96100
else
97101
unit *= get_unit(op.domain.variables)
98102
end
99-
return oneunit(get_unit(args[1]) * unit)
103+
return get_unit(args[1]) * unit
100104
end
101105

102106
equivalent(x, y) = isequal(x, y)
@@ -197,7 +201,11 @@ function _validate(terms::Vector, labels::Vector{String}; info::String = "")
197201
first_label = label
198202
elseif !equivalent(first_unit, equnit)
199203
valid = false
200-
@warn("$info: units [$(first_unit)] for $(first_label) and [$(equnit)] for $(label) do not match.")
204+
str = "$info: units [$(first_unit)] for $(first_label) and [$(equnit)] for $(label) do not match."
205+
if oneunit(first_unit) == oneunit(equnit)
206+
str *= " If there are non-SI units in the system, please use symbolic units like `us\"ms\"`"
207+
end
208+
@warn(str)
201209
end
202210
end
203211
end
@@ -227,7 +235,11 @@ function _validate(conn::Connection; info::String = "")
227235
bunit = safe_get_unit(sst[j], info * string(nameof(s)) * "#$j")
228236
if !equivalent(aunit, bunit)
229237
valid = false
230-
@warn("$info: connected system unknowns $x and $(sst[j]) have mismatched units.")
238+
str = "$info: connected system unknowns $x ($aunit) and $(sst[j]) ($bunit) have mismatched units."
239+
if oneunit(aunit) == oneunit(bunit)
240+
str *= " If there are non-SI units in the system, please use symbolic units like `us\"ms\"`"
241+
end
242+
@warn(str)
231243
end
232244
end
233245
end

test/dq_units.jl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ using ModelingToolkit: t, D
1313
@test MT.get_unit(0.5) == MT.unitless
1414
@test MT.get_unit(MT.SciMLBase.NullParameters()) == MT.unitless
1515

16-
# Prohibited unit types
17-
@parameters γ [unit = 1u"ms"]
18-
@test_throws MT.ValidationError MT.get_unit(γ)
19-
2016
eqs = [D(E) ~ P - E / τ
2117
0 ~ P]
2218
@test MT.validate(eqs)
@@ -59,11 +55,18 @@ end
5955
end
6056
@named p1 = Pin()
6157
@named p2 = Pin()
62-
@test_throws MT.ValidationError @named op = OtherPin()
6358
@named lp = LongPin()
6459
good_eqs = [connect(p1, p2)]
6560
@test MT.validate(good_eqs)
6661
@named sys = ODESystem(good_eqs, t, [], [])
62+
@named op = OtherPin()
63+
bad_eqs = [connect(p1, op)]
64+
@test !MT.validate(bad_eqs)
65+
@test_throws MT.ValidationError @named sys = ODESystem(bad_eqs, t, [], [])
66+
@named op2 = OtherPin()
67+
good_eqs = [connect(op, op2)]
68+
@test MT.validate(good_eqs)
69+
@named sys = ODESystem(good_eqs, t, [], [])
6770

6871
# Array variables
6972
@variables x(t)[1:3] [unit = u"m"]
@@ -85,18 +88,22 @@ eqs = [
8588
eqs = [D(E) ~ P - E / τ
8689
P ~ Q]
8790

91+
noiseeqs = [0.1us"W",
92+
0.1us"W"]
93+
@named sys = SDESystem(eqs, noiseeqs, t, [P, E], [τ, Q])
94+
8895
noiseeqs = [0.1u"W",
8996
0.1u"W"]
90-
@named sys = SDESystem(eqs, noiseeqs, t, [P, E], [τ, Q])
97+
@test_throws MT.ValidationError @named sys = SDESystem(eqs, noiseeqs, t, [P, E], [τ, Q])
9198

9299
# With noise matrix
93-
noiseeqs = [0.1u"W" 0.1u"W"
94-
0.1u"W" 0.1u"W"]
100+
noiseeqs = [0.1us"W" 0.1us"W"
101+
0.1us"W" 0.1us"W"]
95102
@named sys = SDESystem(eqs, noiseeqs, t, [P, E], [τ, Q])
96103

97104
# Invalid noise matrix
98-
noiseeqs = [0.1u"W" 0.1u"W"
99-
0.1u"W" 0.1u"s"]
105+
noiseeqs = [0.1us"W" 0.1us"W"
106+
0.1us"W" 0.1us"s"]
100107
@test !MT.validate(eqs, noiseeqs)
101108

102109
# Non-trivial simplifications
@@ -210,3 +217,10 @@ end
210217
@test prob = ODEProblem(
211218
pend, u0, (0.0, 1.0), p; guesses = guess, check_units = false) isa Any
212219
end
220+
221+
@parameters p [unit = u"L/s"] d [unit = u"s^(-1)"]
222+
@parameters tt [unit = u"s"]
223+
@variables X(tt) [unit = u"L"]
224+
DD = Differential(tt)
225+
eqs = [DD(X) ~ p - d * X + d * X]
226+
@test ModelingToolkit.validate(eqs)

0 commit comments

Comments
 (0)