Skip to content

Commit 837fb38

Browse files
author
Lucas Morton
committed
Refactor to make it possible to specialize get_unit for custom types & registered functions.
1 parent 16033d8 commit 837fb38

File tree

2 files changed

+96
-113
lines changed

2 files changed

+96
-113
lines changed

src/systems/validation.jl

Lines changed: 90 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -9,76 +9,103 @@ function screen_unit(result)
99
result isa Unitful.Unitlike || throw(ValidationError("Unit must be a subtype of Unitful.Unitlike, not $(typeof(result))."))
1010
result isa Unitful.ScalarUnits || throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead."))
1111
result == u"°" && throw(ValidationError("Degrees are not supported. Use radians instead."))
12-
end
13-
"Find the unit of a symbolic item."
14-
get_unit(x::Real) = unitless
15-
function get_unit(x::Unitful.Quantity)
16-
result = Unitful.unit(x)
17-
screen_unit(result)
18-
return result
12+
result
1913
end
2014
equivalent(x,y) = isequal(1*x,1*y)
2115
unitless = Unitful.unit(1)
2216

17+
#For dispatching get_unit
18+
Literal = Union{Sym,Symbolics.ArrayOp,Symbolics.Arr,Symbolics.CallWithMetadata}
19+
Conditional = Union{typeof(ifelse),typeof(IfElse.ifelse)}
20+
Comparison = Union{typeof(Base.:>), typeof(Base.:<), typeof(==)}
21+
22+
"Find the unit of a symbolic item."
23+
get_unit(x::Real) = unitless
24+
get_unit(x::Unitful.Quantity) = screen_unit(Unitful.unit(x))
2325
get_unit(x::AbstractArray) = map(get_unit,x)
2426
get_unit(x::Num) = get_unit(value(x))
25-
function get_unit(x::Symbolic)
26-
if x isa Sym || operation(x) isa Sym || (operation(x) isa Term && operation(x).f == getindex) || x isa Symbolics.ArrayOp
27-
if x.metadata !== nothing
28-
symunits = get(x.metadata, VariableUnit, unitless)
29-
screen_unit(symunits)
30-
else
31-
symunits = unitless
32-
end
33-
return symunits
34-
elseif operation(x) isa Differential
35-
return get_unit(arguments(x)[1]) / get_unit(operation(x).x)
36-
elseif operation(x) isa Integral
37-
unit = 1
38-
if operation(x).x isa Vector
39-
for u in operation(x).x
40-
unit *= get_unit(u)
41-
end
42-
else
43-
unit *= get_unit(operation(x).x)
44-
end
45-
return get_unit(arguments(x)[1]) * unit
46-
elseif operation(x) isa Difference
47-
return get_unit(arguments(x)[1]) / get_unit(operation(x).t) #TODO: make this same as Differential
48-
elseif x isa Pow
49-
pargs = arguments(x)
50-
base,expon = get_unit.(pargs)
51-
@assert expon isa Unitful.DimensionlessUnits
52-
if base == unitless
53-
unitless
54-
else
55-
pargs[2] isa Number ? operation(x)(base, pargs[2]) : operation(x)(1*base, pargs[2])
56-
end
57-
elseif x isa Add # Cannot simply add the units b/c they may differ in magnitude (eg, kg vs g)
58-
terms = get_unit.(arguments(x))
59-
firstunit = terms[1]
60-
for other in terms[2:end]
61-
termlist = join(map(repr,terms),", ")
62-
equivalent(other,firstunit) || throw(ValidationError(", in sum $x, units [$termlist] do not match."))
63-
end
64-
return firstunit
65-
elseif operation(x) in ( Base.:> , Base.:< , == )
66-
terms = get_unit.(arguments(x))
67-
equivalent(terms[1],terms[2]) || throw(ValidationError(", in comparison $x, units [$(terms[1])] and [$(terms[2])] do not match."))
68-
return unitless
69-
elseif operation(x) == ifelse || operation(x) == IfElse.ifelse
70-
terms = get_unit.(arguments(x))
71-
terms[1] == unitless || throw(ValidationError(", in $x, [$(terms[1])] is not dimensionless."))
72-
equivalent(terms[2],terms[3]) || throw(ValidationError(", in $x, units [$(terms[2])] and [$(terms[3])] do not match."))
73-
return terms[2]
74-
elseif operation(x) == Symbolics._mapreduce
75-
if x.arguments[2] == +
76-
get_unit(x.arguments[3])
77-
else
78-
throw(ValidationError("Unsupported array operation $x"))
27+
get_unit(x::Literal) = screen_unit(getmetadata(x,VariableUnit, unitless))
28+
get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x)
29+
get_unit(op::Difference, args) = get_unit(args[1]) / get_unit(op.t) #why are these not identical?!?
30+
get_unit(op::typeof(getindex),args) = get_unit(args[1])
31+
function get_unit(op,args) #Fallback
32+
result = op(1 .* get_unit.(args)...)
33+
try
34+
unit(result)
35+
catch
36+
throw(ValidationError("Unable to get unit for operation $op with arguments $args."))
37+
end
38+
end
39+
40+
function get_unit(op::Integral,args)
41+
unit = 1
42+
if op.x isa Vector
43+
for u in op.x
44+
unit *= get_unit(u)
7945
end
8046
else
81-
return get_unit(operation(x)(1 .* get_unit.(arguments(x))...))
47+
unit *= get_unit(op.x)
48+
end
49+
return get_unit(args[1]) * unit
50+
end
51+
52+
function get_unit(x::Pow)
53+
pargs = arguments(x)
54+
base,expon = get_unit.(pargs)
55+
@assert expon isa Unitful.DimensionlessUnits
56+
if base == unitless
57+
unitless
58+
else
59+
pargs[2] isa Number ? base^pargs[2] : (1*base)^pargs[2]
60+
end
61+
end
62+
63+
function get_unit(x::Add)
64+
terms = get_unit.(arguments(x))
65+
firstunit = terms[1]
66+
for other in terms[2:end]
67+
termlist = join(map(repr, terms), ", ")
68+
equivalent(other, firstunit) || throw(ValidationError(", in sum $x, units [$termlist] do not match."))
69+
end
70+
return firstunit
71+
end
72+
73+
function get_unit(op::Conditional, args)
74+
terms = get_unit.(args)
75+
terms[1] == unitless || throw(ValidationError(", in $x, [$(terms[1])] is not dimensionless."))
76+
equivalent(terms[2], terms[3]) || throw(ValidationError(", in $x, units [$(terms[2])] and [$(terms[3])] do not match."))
77+
return terms[2]
78+
end
79+
80+
function get_unit(op::typeof(Symbolics._mapreduce),args)
81+
if args[2] == +
82+
get_unit(args[3])
83+
else
84+
throw(ValidationError("Unsupported array operation $op"))
85+
end
86+
end
87+
88+
function get_unit(op::Comparison, args)
89+
terms = get_unit.(args)
90+
equivalent(terms[1], terms[2]) || throw(ValidationError(", in comparison $x, units [$(terms[1])] and [$(terms[2])] do not match."))
91+
return unitless
92+
end
93+
94+
function get_unit(x::Symbolic)
95+
if SymbolicUtils.istree(x)
96+
op = operation(x)
97+
if op isa Sym # Not a real function call, just a dependent variable. Unit is on the Sym.
98+
return screen_unit(getmetadata(x, VariableUnit, unitless))
99+
elseif op isa Term && !(operation(op) isa Term) #
100+
gp = getmetadata(x,Symbolics.GetindexParent,nothing)
101+
return screen_unit(getmetadata(gp, VariableUnit, unitless))
102+
elseif op isa Term
103+
return screen_unit(getmetadata(x, VariableUnit, unitless))
104+
end
105+
args = arguments(x)
106+
return get_unit(op, args)
107+
else #This function should only be reached by Terms, for which `istree` is true
108+
throw(ArgumentError("Unsupported value $x."))
82109
end
83110
end
84111

@@ -92,7 +119,7 @@ function safe_get_unit(term, info)
92119
@warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.")
93120
elseif err isa ValidationError
94121
@warn(info*err.message)
95-
elseif err isa MethodError
122+
elseif err isa MethodError #Warning: Unable to get unit for operation x[1] with arguments SymbolicUtils.Sym{Real, Base.ImmutableDict{DataType, Any}}[t].
96123
@warn("$info: no method matching $(err.f) for arguments $(typeof.(err.args)).")
97124
else
98125
rethrow()

test/units.jl

Lines changed: 6 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -52,45 +52,13 @@ eqs = [D(E) ~ P - E/τ
5252
eqs = [0 ~ σ*(y - x)]
5353
@test MT.validate(eqs)
5454

55-
#Array variables
56-
@variables t x[1:3,1:3](t)
55+
##Array variables
56+
@variables t [unit = u"s"] x[1:3](t) [unit = u"m"]
57+
@parameters v[1:3] = [1,2,3] [unit = u"m/s"]
5758
D = Differential(t)
58-
eqs = D.(x) .~ x
59+
eqs = D.(x) .~ v
5960
ODESystem(eqs,name=:sys)
6061

61-
# Array ops
62-
using Symbolics: unwrap, wrap
63-
using LinearAlgebra
64-
@variables t
65-
sts = @variables x[1:3](t) y(t)
66-
ps = @parameters p[1:3] = [1, 2, 3]
67-
D = Differential(t)
68-
eqs = [
69-
collect(D.(x) ~ x)
70-
D(y) ~ norm(x)*y
71-
]
72-
ODESystem(eqs, t, [sts...;], [ps...;],name=:sys)
73-
74-
#= Not supported yet b/c iterate doesn't work on unitful array
75-
# Array ops with units
76-
@variables t [unit =u"s"]
77-
sts = @variables x[1:3](t) [unit = u"kg"] y(t) [unit = u"kg"]
78-
ps = @parameters b [unit = u"s"^-1]
79-
D = Differential(t)
80-
eqs = [
81-
collect(D.(x) ~ b*x)
82-
D(y) ~ b*norm(x)
83-
]
84-
ODESystem(eqs, t, [sts...;], [ps...;])
85-
86-
#Array variables with units
87-
@variables t [unit = u"s"] x[1:3,1:3](t) [unit = u"kg"]
88-
@parameters a [unit = u"s"^-1]
89-
D = Differential(t)
90-
eqs = D.(x) .~ a*x
91-
ODESystem(eqs)
92-
=#
93-
9462
#Difference equation with units
9563
@parameters t [unit = u"s"] a [unit = u"s"^-1]
9664
@variables x(t) [unit = u"kg"]
@@ -99,7 +67,7 @@ D = Difference(t; dt = 0.1u"s")
9967
eqs = [
10068
δ(x) ~ a*x
10169
]
102-
de = ODESystem(eqs, t, [x, y], [a],name=:sys)
70+
de = ODESystem(eqs, t, [x], [a],name=:sys)
10371

10472

10573
@parameters t
@@ -202,20 +170,8 @@ maj2 = MassActionJump(γ, [S => 1], [S => -1])
202170
@parameters t
203171
vars = @variables x(t)
204172
D = Differential(t)
205-
eqs =
206-
[
173+
eqs = [
207174
D(x) ~ IfElse.ifelse(t>0.1,2,1)
208175
]
209176
@named sys = ODESystem(eqs, t, vars, [])
210177

211-
#Vectors of symbols
212-
@parameters t
213-
@register dummy(vector::Vector{Num}, scalar)
214-
dummy(vector, scalar) = vector[1] .- scalar
215-
216-
@variables vec[1:2](t)
217-
vec = collect(vec)
218-
eqs = [vec .~ dummy(vec, vec[1]);]
219-
sts = vcat(vec)
220-
ODESystem(eqs, t, [sts...;], [], name=:sys)
221-

0 commit comments

Comments
 (0)