Skip to content

Commit 93c1ec2

Browse files
Merge pull request #1229 from lamorton/unit_convert
Refactor `get_unit`
2 parents 16033d8 + 4c220c6 commit 93c1ec2

File tree

3 files changed

+167
-161
lines changed

3 files changed

+167
-161
lines changed

docs/src/basics/Validation.md

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,24 @@ Units may assigned with the following syntax.
99
```julia
1010
using ModelingToolkit, Unitful
1111
@variables t [unit = u"s"] x(t) [unit = u"m"] g(t) w(t) [unit = "Hz"]
12-
#Or,
12+
1313
@variables(t, [unit = u"s"], x(t), [unit = u"m"], g(t), w(t), [unit = "Hz"])
14-
#Or,
14+
1515
@variables(begin
1616
t, [unit = u"s"],
1717
x(t), [unit = u"m"],
1818
g(t),
1919
w(t), [unit = "Hz"]
2020
end)
21+
22+
# Simultaneously set default value (use plain numbers, not quantities)
23+
@variable x=10 [unit = u"m"]
24+
25+
# Symbolic array: unit applies to all elements
26+
@variable x[1:3] [unit = u"m"]
2127
```
2228

23-
Do not use `quantities` such as `1u"s"` or `1/u"s"` or `u"1/s"` as these will result in errors; instead use `u"s"` or `u"s^1"`.
29+
Do not use `quantities` such as `1u"s"`, `1/u"s"` or `u"1/s"` as these will result in errors; instead use `u"s"`, `u"s^-1"`, or `u"s"^-1`.
2430

2531
## Unit Validation & Inspection
2632

@@ -62,8 +68,38 @@ eqs = eqs = [D(E) ~ P - E/τ,
6268
0 ~ P ]
6369
ModelingToolkit.validate(eqs) #Returns false while displaying a warning message
6470
```
71+
## User-Defined Registered Functions and Types
72+
73+
In order to validate user-defined types and `register`ed functions, specialize `get_unit`. Single-parameter calls to `get_unit`
74+
expect an object type, while two-parameter calls expect a function type as the first argument, and a vector of arguments as the
75+
second argument.
76+
77+
```julia
78+
using ModelingToolkit
79+
# Composite type parameter in registered function
80+
@parameters t
81+
D = Differential(t)
82+
struct NewType
83+
f
84+
end
85+
@register dummycomplex(complex::Num, scalar)
86+
dummycomplex(complex, scalar) = complex.f - scalar
87+
88+
c = NewType(1)
89+
MT.get_unit(x::NewType) = MT.get_unit(x.f)
90+
function MT.get_unit(op::typeof(dummycomplex),args)
91+
argunits = MT.get_unit.(args)
92+
MT.get_unit(-,args)
93+
end
94+
95+
sts = @variables a(t)=0 [unit = u"cm"]
96+
ps = @parameters s=-1 [unit = u"cm"] c=c [unit = u"cm"]
97+
eqs = [D(a) ~ dummycomplex(c, s);]
98+
sys = ODESystem(eqs, t, [sts...;], [ps...;], name=:sys)
99+
sys_simple = structural_simplify(sys)
100+
```
65101

66-
## `Unitful` Literals & User-Defined Functions
102+
## `Unitful` Literals
67103

68104
In order for a function to work correctly during both validation & execution, the function must be unit-agnostic. That is, no unitful literals may be used. Any unitful quantity must either be a `parameter` or `variable`. For example, these equations will not validate successfully.
69105

src/systems/validation.jl

Lines changed: 101 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -5,80 +5,119 @@ struct ValidationError <: Exception
55
message::String
66
end
77

8+
"Throw exception on invalid unit types, otherwise return argument."
89
function screen_unit(result)
910
result isa Unitful.Unitlike || throw(ValidationError("Unit must be a subtype of Unitful.Unitlike, not $(typeof(result))."))
1011
result isa Unitful.ScalarUnits || throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead."))
1112
result == u"°" && throw(ValidationError("Degrees are not supported. Use radians instead."))
13+
result
1214
end
13-
"Find the unit of a symbolic item."
14-
get_unit(x::Real) = unitless
15-
function get_unit(x::Unitful.Quantity)
16-
result = Unitful.unit(x)
17-
screen_unit(result)
18-
return result
19-
end
15+
16+
"""Test unit equivalence.
17+
18+
Example of implemented behavior:
19+
```julia
20+
using ModelingToolkit, Unitful
21+
MT = ModelingToolkit
22+
@parameters γ P [unit = u"MW"] E [unit = u"kJ"] τ [unit = u"ms"]
23+
@test MT.equivalent(u"MW" ,u"kJ/ms") # Understands prefixes
24+
@test !MT.equivalent(u"m", u"cm") # Units must be same magnitude
25+
@test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E/τ)^γ)) # Handles symbolic exponents
26+
```
27+
"""
2028
equivalent(x,y) = isequal(1*x,1*y)
2129
unitless = Unitful.unit(1)
2230

31+
#For dispatching get_unit
32+
Literal = Union{Sym,Symbolics.ArrayOp,Symbolics.Arr,Symbolics.CallWithMetadata}
33+
Conditional = Union{typeof(ifelse),typeof(IfElse.ifelse)}
34+
Comparison = Union{typeof.([==, !=, , <, <=, , >, >=, ])...}
35+
36+
"Find the unit of a symbolic item."
37+
get_unit(x::Real) = unitless
38+
get_unit(x::Unitful.Quantity) = screen_unit(Unitful.unit(x))
2339
get_unit(x::AbstractArray) = map(get_unit,x)
2440
get_unit(x::Num) = get_unit(value(x))
25-
function get_unit(x::Symbolic)
26-
if x isa Sym || operation(x) isa Sym || (operation(x) isa Term && operation(x).f == getindex) || x isa Symbolics.ArrayOp
27-
if x.metadata !== nothing
28-
symunits = get(x.metadata, VariableUnit, unitless)
29-
screen_unit(symunits)
30-
else
31-
symunits = unitless
32-
end
33-
return symunits
34-
elseif operation(x) isa Differential
35-
return get_unit(arguments(x)[1]) / get_unit(operation(x).x)
36-
elseif operation(x) isa Integral
37-
unit = 1
38-
if operation(x).x isa Vector
39-
for u in operation(x).x
40-
unit *= get_unit(u)
41-
end
42-
else
43-
unit *= get_unit(operation(x).x)
44-
end
45-
return get_unit(arguments(x)[1]) * unit
46-
elseif operation(x) isa Difference
47-
return get_unit(arguments(x)[1]) / get_unit(operation(x).t) #TODO: make this same as Differential
48-
elseif x isa Pow
49-
pargs = arguments(x)
50-
base,expon = get_unit.(pargs)
51-
@assert expon isa Unitful.DimensionlessUnits
52-
if base == unitless
53-
unitless
54-
else
55-
pargs[2] isa Number ? operation(x)(base, pargs[2]) : operation(x)(1*base, pargs[2])
56-
end
57-
elseif x isa Add # Cannot simply add the units b/c they may differ in magnitude (eg, kg vs g)
58-
terms = get_unit.(arguments(x))
59-
firstunit = terms[1]
60-
for other in terms[2:end]
61-
termlist = join(map(repr,terms),", ")
62-
equivalent(other,firstunit) || throw(ValidationError(", in sum $x, units [$termlist] do not match."))
63-
end
64-
return firstunit
65-
elseif operation(x) in ( Base.:> , Base.:< , == )
66-
terms = get_unit.(arguments(x))
67-
equivalent(terms[1],terms[2]) || throw(ValidationError(", in comparison $x, units [$(terms[1])] and [$(terms[2])] do not match."))
68-
return unitless
69-
elseif operation(x) == ifelse || operation(x) == IfElse.ifelse
70-
terms = get_unit.(arguments(x))
71-
terms[1] == unitless || throw(ValidationError(", in $x, [$(terms[1])] is not dimensionless."))
72-
equivalent(terms[2],terms[3]) || throw(ValidationError(", in $x, units [$(terms[2])] and [$(terms[3])] do not match."))
73-
return terms[2]
74-
elseif operation(x) == Symbolics._mapreduce
75-
if x.arguments[2] == +
76-
get_unit(x.arguments[3])
77-
else
78-
throw(ValidationError("Unsupported array operation $x"))
41+
get_unit(x::Literal) = screen_unit(getmetadata(x,VariableUnit, unitless))
42+
get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x)
43+
get_unit(op::Difference, args) = get_unit(args[1]) / get_unit(op.t)
44+
get_unit(op::typeof(getindex),args) = get_unit(args[1])
45+
function get_unit(op,args) # Fallback
46+
result = op(1 .* get_unit.(args)...)
47+
try
48+
unit(result)
49+
catch
50+
throw(ValidationError("Unable to get unit for operation $op with arguments $args."))
51+
end
52+
end
53+
54+
function get_unit(op::Integral,args)
55+
unit = 1
56+
if op.x isa Vector
57+
for u in op.x
58+
unit *= get_unit(u)
7959
end
8060
else
81-
return get_unit(operation(x)(1 .* get_unit.(arguments(x))...))
61+
unit *= get_unit(op.x)
62+
end
63+
return get_unit(args[1]) * unit
64+
end
65+
66+
function get_unit(x::Pow)
67+
pargs = arguments(x)
68+
base,expon = get_unit.(pargs)
69+
@assert expon isa Unitful.DimensionlessUnits
70+
if base == unitless
71+
unitless
72+
else
73+
pargs[2] isa Number ? base^pargs[2] : (1*base)^pargs[2]
74+
end
75+
end
76+
77+
function get_unit(x::Add)
78+
terms = get_unit.(arguments(x))
79+
firstunit = terms[1]
80+
for other in terms[2:end]
81+
termlist = join(map(repr, terms), ", ")
82+
equivalent(other, firstunit) || throw(ValidationError(", in sum $x, units [$termlist] do not match."))
83+
end
84+
return firstunit
85+
end
86+
87+
function get_unit(op::Conditional, args)
88+
terms = get_unit.(args)
89+
terms[1] == unitless || throw(ValidationError(", in $x, [$(terms[1])] is not dimensionless."))
90+
equivalent(terms[2], terms[3]) || throw(ValidationError(", in $x, units [$(terms[2])] and [$(terms[3])] do not match."))
91+
return terms[2]
92+
end
93+
94+
function get_unit(op::typeof(Symbolics._mapreduce),args)
95+
if args[2] == +
96+
get_unit(args[3])
97+
else
98+
throw(ValidationError("Unsupported array operation $op"))
99+
end
100+
end
101+
102+
function get_unit(op::Comparison, args)
103+
terms = get_unit.(args)
104+
equivalent(terms[1], terms[2]) || throw(ValidationError(", in comparison $op, units [$(terms[1])] and [$(terms[2])] do not match."))
105+
return unitless
106+
end
107+
108+
function get_unit(x::Symbolic)
109+
if SymbolicUtils.istree(x)
110+
op = operation(x)
111+
if op isa Sym || (op isa Term && operation(op) isa Term) # Dependent variables, not function calls
112+
return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i]
113+
elseif op isa Term && !(operation(op) isa Term)
114+
gp = getmetadata(x,Symbolics.GetindexParent,nothing) # Like x[1](t)
115+
return screen_unit(getmetadata(gp, VariableUnit, unitless))
116+
end # Actual function calls:
117+
args = arguments(x)
118+
return get_unit(op, args)
119+
else # This function should only be reached by Terms, for which `istree` is true
120+
throw(ArgumentError("Unsupported value $x."))
82121
end
83122
end
84123

0 commit comments

Comments
 (0)