|
1 | 1 | Base.:*(x::Union{Num,Symbolic},y::Unitful.AbstractQuantity) = x * y
|
2 | 2 |
|
3 |
| - |
4 |
| -function vartype(x::Symbolic) |
5 |
| - if !(x.metadata isa Nothing) |
6 |
| - return haskey(x.metadata, VariableUnit) ? x.metadata[VariableUnit] : 1 |
7 |
| - end |
8 |
| - 1 |
9 |
| -end |
10 |
| -vartype(x::Num) = vartype(value(x)) |
11 |
| - |
12 |
| -instantiate(x) = 1 |
13 |
| -instantiate(x::Unitful.Quantity) = 1 * Unitful.unit(x) |
14 |
| -instantiate(x::Num) = instantiate(value(x)) |
15 |
| -function instantiate(x::Symbolic) |
| 3 | +"Find the units of a symbolic item." |
| 4 | +get_units(x) = 1 |
| 5 | +get_units(x::Unitful.Quantity) = 1 * Unitful.unit(x) |
| 6 | +get_units(x::Num) = get_units(value(x)) |
| 7 | +function get_units(x::Symbolic) |
16 | 8 | vx = value(x)
|
17 | 9 | if vx isa Sym || operation(vx) isa Sym || (operation(vx) isa Term && operation(vx).f == getindex) || vx isa Symbolics.ArrayOp
|
18 |
| - return oneunit(1 * vartype(vx)) |
| 10 | + if x.metadata !== nothing |
| 11 | + symunits = haskey(x.metadata, VariableUnit) ? x.metadata[VariableUnit] : 1 |
| 12 | + else |
| 13 | + symunits = 1 |
| 14 | + end |
| 15 | + return oneunit(1 * symunits) |
19 | 16 | elseif operation(vx) isa Differential || operation(vx) isa Difference
|
20 |
| - return instantiate(arguments(vx)[1]) / instantiate(arguments(arguments(vx)[1])[1]) |
| 17 | + return get_units(arguments(vx)[1]) / get_units(arguments(arguments(vx)[1])[1]) |
21 | 18 | elseif vx isa Pow
|
22 | 19 | pargs = arguments(vx)
|
23 |
| - base,expon = instantiate.(pargs) |
| 20 | + base,expon = get_units.(pargs) |
24 | 21 | uconvert(NoUnits, expon) # This acts as an assertion
|
25 | 22 | return base == 1 ? 1 : operation(vx)(base, pargs[2])
|
26 | 23 | elseif vx isa Add # Cannot simply add the units b/c they may differ in magnitude (eg, kg vs g)
|
27 |
| - terms = instantiate.(arguments(vx)) |
| 24 | + terms = get_units.(arguments(vx)) |
28 | 25 | firstunit = unit(terms[1])
|
29 | 26 | @assert all(map(x -> ustrip(firstunit, x) == 1, terms[2:end]))
|
30 | 27 | return 1 * firstunit
|
31 | 28 | elseif operation(vx) == Symbolics._mapreduce
|
32 | 29 | if vx.arguments[2] == +
|
33 |
| - instantiate(vx.arguments[3]) |
| 30 | + get_units(vx.arguments[3]) |
34 | 31 | else
|
35 | 32 | throw(ArgumentError("Unknown array operation $vx"))
|
36 | 33 | end
|
37 | 34 | else
|
38 |
| - return oneunit(operation(vx)(instantiate.(arguments(vx))...)) |
| 35 | + return oneunit(operation(vx)(get_units.(arguments(vx))...)) |
39 | 36 | end
|
40 | 37 | end
|
41 | 38 |
|
42 |
| -function validate(eq::ModelingToolkit.Equation; eqnum = 1) |
43 |
| - lhs = rhs = nothing |
| 39 | +"Get units of term, returning nothing & showing warning instead of throwing errors." |
| 40 | +function safe_get_units(term, info) |
| 41 | + side = nothing |
44 | 42 | try
|
45 |
| - lhs = instantiate(eq.lhs) |
| 43 | + side = get_units(term) |
46 | 44 | catch err
|
47 | 45 | if err isa Unitful.DimensionError
|
48 |
| - @warn("In left-hand side of eq. #$eqnum: $(eq.lhs), $(err.x) and $(err.y) are not dimensionally compatible.") |
49 |
| - elseif err isa MethodError |
50 |
| - @warn("In left-hand side of eq. #$eqnum: $(err.f) doesn't accept $(err.args).") |
| 46 | + @warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.") |
| 47 | + elseif err isa MethodError #TODO: filter for only instances where the arguments are unitful |
| 48 | + @warn("$info: no method matching $(err.f) for arguments $(err.args).") |
51 | 49 | else
|
52 | 50 | rethrow()
|
53 | 51 | end
|
54 | 52 | end
|
55 |
| - try |
56 |
| - rhs = instantiate(eq.rhs) |
57 |
| - catch err |
58 |
| - if err isa Unitful.DimensionError |
59 |
| - @warn("In right-hand side of eq. #$eqnum: $(eq.rhs), $(err.x) and $(err.y) are not dimensionally compatible.") |
60 |
| - elseif err isa MethodError |
61 |
| - @warn("In right-hand side of eq. #$eqnum: $(err.f) doesn't accept $(err.args).") |
62 |
| - else |
63 |
| - rethrow() |
64 |
| - end |
65 |
| - end |
66 |
| - if (rhs !== nothing) && (lhs !== nothing) |
67 |
| - if !isequal(lhs, rhs) |
68 |
| - @warn("In eq. #$eqnum, left-side units ($lhs) and right-side units ($rhs) don't match.") |
| 53 | + side |
| 54 | +end |
| 55 | + |
| 56 | +function _validate(terms::Vector,labels::Vector; info::String = "") |
| 57 | + equnits = safe_get_units.(terms,info.*labels) |
| 58 | + allthere = all(map(x->x!==nothing,equnits)) |
| 59 | + allmatching = true |
| 60 | + if allthere |
| 61 | + for idx in 2:length(equnits) |
| 62 | + if !isequal(equnits[1],equnits[idx]) |
| 63 | + allmatching = false |
| 64 | + @warn("$info: units $(equnits[1]) for $(labels[1]) and $(equnits[idx]) for $(labels[idx]) do not match.") |
| 65 | + end |
69 | 66 | end
|
70 | 67 | end
|
71 |
| - (rhs !== nothing) && (lhs !== nothing) && isequal(lhs, rhs) |
| 68 | + allthere && allmatching |
| 69 | +end |
| 70 | + |
| 71 | +function validate(eq::ModelingToolkit.Equation; info::String = "") |
| 72 | + labels = ["left-hand side", "right-hand side"] |
| 73 | + terms = [eq.lhs,eq.rhs] |
| 74 | + _validate(terms,labels,info = info) |
| 75 | +end |
| 76 | + |
| 77 | +function validate(eq::ModelingToolkit.Equation,noiseterm; info::String = "") |
| 78 | + labels = ["left-hand side", "right-hand side","noise term"] |
| 79 | + terms = [eq.lhs,eq.rhs,noiseterm] |
| 80 | + _validate(terms,labels,info = info) |
| 81 | +end |
| 82 | + |
| 83 | +function validate(eq::ModelingToolkit.Equation,noisevec::Vector; info::String = "") |
| 84 | + labels = vcat(["left-hand side", "right-hand side"],"noise term #".* string.(1:length(noisevec))) |
| 85 | + terms = vcat([eq.lhs,eq.rhs],noisevec) |
| 86 | + _validate(terms,labels,info = info) |
72 | 87 | end
|
73 | 88 |
|
74 | 89 | function validate(eqs::Vector{ModelingToolkit.Equation})
|
75 |
| - correct = [validate(eqs[idx],eqnum=idx) for idx in 1:length(eqs)] |
76 |
| - all(correct) || throw(ArgumentError("Invalid equations, see warnings for details.")) |
| 90 | + all([validate(eqs[idx],info = "In eq. #$idx") for idx in 1:length(eqs)]) |
| 91 | +end |
| 92 | + |
| 93 | +function validate(eqs::Vector{ModelingToolkit.Equation},noise::Vector) |
| 94 | + all([validate(eqs[idx],noise[idx],info = "In eq. #$idx") for idx in 1:length(eqs)]) |
77 | 95 | end
|
78 | 96 |
|
79 |
| -validate(sys::AbstractODESystem) = validate(equations(sys)) |
| 97 | +function validate(eqs::Vector{ModelingToolkit.Equation},noise::Matrix) |
| 98 | + all([validate(eqs[idx],noise[idx,:],info = "In eq. #$idx") for idx in 1:length(eqs)]) |
| 99 | +end |
| 100 | + |
| 101 | +"Returns true iff units of equations are valid." |
| 102 | +validate(eqs::Vector) = validate(convert.(ModelingToolkit.Equation,eqs)) |
| 103 | + |
| 104 | +"Throws error if units of equations are invalid." |
| 105 | +check_units(eqs...) = validate(eqs...) || throw(ArgumentError("Some equations had invalid units. See warnings for details.")) |
0 commit comments