Skip to content

Commit 91ba532

Browse files
committed
Removed vx b/c should never need to unwrap a Num inside this function. Also, more idiomatic getindex with default.
1 parent d4b508f commit 91ba532

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

src/systems/validation.jl

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,33 @@ get_units(x::Real) = 1
55
get_units(x::Unitful.Quantity) = 1 * Unitful.unit(x)
66
get_units(x::Num) = get_units(value(x))
77
function get_units(x::Symbolic)
8-
vx = value(x)
9-
if vx isa Sym || operation(vx) isa Sym || (operation(vx) isa Term && operation(vx).f == getindex) || vx isa Symbolics.ArrayOp
8+
if x isa Sym || operation(x) isa Sym || (operation(x) isa Term && operation(x).f == getindex) || x isa Symbolics.ArrayOp
109
if x.metadata !== nothing
11-
symunits = haskey(x.metadata, VariableUnit) ? x.metadata[VariableUnit] : 1
10+
symunits = get(x.metadata, VariableUnit, 1)
1211
else
1312
symunits = 1
1413
end
1514
return oneunit(1 * symunits)
16-
elseif operation(vx) isa Differential || operation(vx) isa Difference
17-
return get_units(arguments(vx)[1]) / get_units(arguments(arguments(vx)[1])[1])
18-
elseif vx isa Pow
19-
pargs = arguments(vx)
15+
elseif operation(x) isa Differential || operation(x) isa Difference
16+
return get_units(arguments(x)[1]) / get_units(arguments(arguments(x)[1])[1])
17+
elseif x isa Pow
18+
pargs = arguments(x)
2019
base,expon = get_units.(pargs)
2120
uconvert(NoUnits, expon) # This acts as an assertion
22-
return base == 1 ? 1 : operation(vx)(base, pargs[2])
23-
elseif vx isa Add # Cannot simply add the units b/c they may differ in magnitude (eg, kg vs g)
24-
terms = get_units.(arguments(vx))
21+
return base == 1 ? 1 : operation(x)(base, pargs[2])
22+
elseif x isa Add # Cannot simply add the units b/c they may differ in magnitude (eg, kg vs g)
23+
terms = get_units.(arguments(x))
2524
firstunit = unit(terms[1])
2625
@assert all(map(x -> ustrip(firstunit, x) == 1, terms[2:end]))
2726
return 1 * firstunit
28-
elseif operation(vx) == Symbolics._mapreduce
29-
if vx.arguments[2] == +
30-
get_units(vx.arguments[3])
27+
elseif operation(x) == Symbolics._mapreduce
28+
if x.arguments[2] == +
29+
get_units(x.arguments[3])
3130
else
32-
throw(ArgumentError("Unknown array operation $vx"))
31+
throw(ArgumentError("Unknown array operation $x"))
3332
end
3433
else
35-
return oneunit(operation(vx)(get_units.(arguments(vx))...))
34+
return oneunit(operation(x)(get_units.(arguments(x))...))
3635
end
3736
end
3837

@@ -88,10 +87,10 @@ function validate(jumps::ArrayPartition{<:Union{Any, Vector{<:JumpType}}}, t::Sy
8887
all([validate(jumps.x[idx], t, info = labels[idx]) for idx in 1:3])
8988
end
9089

91-
validate(eq::ModelingToolkit.Reaction; info::String = "") = _validate([oderatelaw(eq)],["",], info = info)
92-
validate(eq::ModelingToolkit.Equation; info::String = "") = _validate([eq.lhs, eq.rhs],["left", "right"],info = info)
93-
validate(eq::ModelingToolkit.Equation, term::Union{Symbolic,Unitful.Quantity,Num}; info::String = "") = _validate([eq.lhs, eq.rhs, term],["left","right","noise"],info = info)
94-
validate(eq::ModelingToolkit.Equation, terms::Vector; info::String = "") = _validate(vcat([eq.lhs, eq.rhs], terms),vcat(["left", "right"], "noise #".*string.(1:length(terms))), info = info)
90+
validate(eq::ModelingToolkit.Reaction; info::String = "") = _validate([oderatelaw(eq)], ["",], info = info)
91+
validate(eq::ModelingToolkit.Equation; info::String = "") = _validate([eq.lhs, eq.rhs], ["left", "right"], info = info)
92+
validate(eq::ModelingToolkit.Equation, term::Union{Symbolic,Unitful.Quantity,Num}; info::String = "") = _validate([eq.lhs, eq.rhs, term], ["left","right","noise"], info = info)
93+
validate(eq::ModelingToolkit.Equation, terms::Vector; info::String = "") = _validate(vcat([eq.lhs, eq.rhs], terms), vcat(["left", "right"], "noise #".*string.(1:length(terms))), info = info)
9594

9695
"Returns true iff units of equations are valid."
9796
validate(eqs::Vector; info::String = "") = all([validate(eqs[idx], info = info*"in eq. #$idx") for idx in 1:length(eqs)])

0 commit comments

Comments
 (0)