Skip to content

Commit 3383d35

Browse files
committed
Refactor to accomodate noise equations for SDESystems. Extended to cover Control, Nonlinear, PDE, & Discrete systems. Renamed functions to reflect what they do now.
1 parent 15d60f2 commit 3383d35

File tree

8 files changed

+130
-71
lines changed

8 files changed

+130
-71
lines changed

src/systems/control/controlsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ struct ControlSystem <: AbstractControlSystem
7777
check_parameters(ps, iv)
7878
check_equations(deqs, iv)
7979
check_equations(observed, iv)
80+
check_units(deqs)
8081
new(loss, deqs, iv, dvs, controls, ps, observed, name, systems, defaults)
8182
end
8283
end

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ struct ODESystem <: AbstractODESystem
8888
check_variables(dvs,iv)
8989
check_parameters(ps,iv)
9090
check_equations(deqs,iv)
91-
validate(deqs)
91+
check_units(deqs)
9292
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
9393
end
9494
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ struct SDESystem <: AbstractODESystem
9090
check_variables(dvs,iv)
9191
check_parameters(ps,iv)
9292
check_equations(deqs,iv)
93-
validate(deqs)
93+
check_units(deqs,neqs)
9494
new(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
9595
end
9696
end

src/systems/diffeqs/validation.jl

Lines changed: 69 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,105 @@
11
Base.:*(x::Union{Num,Symbolic},y::Unitful.AbstractQuantity) = x * y
22

3-
4-
function vartype(x::Symbolic)
5-
if !(x.metadata isa Nothing)
6-
return haskey(x.metadata, VariableUnit) ? x.metadata[VariableUnit] : 1
7-
end
8-
1
9-
end
10-
vartype(x::Num) = vartype(value(x))
11-
12-
instantiate(x) = 1
13-
instantiate(x::Unitful.Quantity) = 1 * Unitful.unit(x)
14-
instantiate(x::Num) = instantiate(value(x))
15-
function instantiate(x::Symbolic)
3+
"Find the units of a symbolic item."
4+
get_units(x) = 1
5+
get_units(x::Unitful.Quantity) = 1 * Unitful.unit(x)
6+
get_units(x::Num) = get_units(value(x))
7+
function get_units(x::Symbolic)
168
vx = value(x)
179
if vx isa Sym || operation(vx) isa Sym || (operation(vx) isa Term && operation(vx).f == getindex) || vx isa Symbolics.ArrayOp
18-
return oneunit(1 * vartype(vx))
10+
if x.metadata !== nothing
11+
symunits = haskey(x.metadata, VariableUnit) ? x.metadata[VariableUnit] : 1
12+
else
13+
symunits = 1
14+
end
15+
return oneunit(1 * symunits)
1916
elseif operation(vx) isa Differential || operation(vx) isa Difference
20-
return instantiate(arguments(vx)[1]) / instantiate(arguments(arguments(vx)[1])[1])
17+
return get_units(arguments(vx)[1]) / get_units(arguments(arguments(vx)[1])[1])
2118
elseif vx isa Pow
2219
pargs = arguments(vx)
23-
base,expon = instantiate.(pargs)
20+
base,expon = get_units.(pargs)
2421
uconvert(NoUnits, expon) # This acts as an assertion
2522
return base == 1 ? 1 : operation(vx)(base, pargs[2])
2623
elseif vx isa Add # Cannot simply add the units b/c they may differ in magnitude (eg, kg vs g)
27-
terms = instantiate.(arguments(vx))
24+
terms = get_units.(arguments(vx))
2825
firstunit = unit(terms[1])
2926
@assert all(map(x -> ustrip(firstunit, x) == 1, terms[2:end]))
3027
return 1 * firstunit
3128
elseif operation(vx) == Symbolics._mapreduce
3229
if vx.arguments[2] == +
33-
instantiate(vx.arguments[3])
30+
get_units(vx.arguments[3])
3431
else
3532
throw(ArgumentError("Unknown array operation $vx"))
3633
end
3734
else
38-
return oneunit(operation(vx)(instantiate.(arguments(vx))...))
35+
return oneunit(operation(vx)(get_units.(arguments(vx))...))
3936
end
4037
end
4138

42-
function validate(eq::ModelingToolkit.Equation; eqnum = 1)
43-
lhs = rhs = nothing
39+
"Get units of term, returning nothing & showing warning instead of throwing errors."
40+
function safe_get_units(term, info)
41+
side = nothing
4442
try
45-
lhs = instantiate(eq.lhs)
43+
side = get_units(term)
4644
catch err
4745
if err isa Unitful.DimensionError
48-
@warn("In left-hand side of eq. #$eqnum: $(eq.lhs), $(err.x) and $(err.y) are not dimensionally compatible.")
49-
elseif err isa MethodError
50-
@warn("In left-hand side of eq. #$eqnum: $(err.f) doesn't accept $(err.args).")
46+
@warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.")
47+
elseif err isa MethodError #TODO: filter for only instances where the arguments are unitful
48+
@warn("$info: no method matching $(err.f) for arguments $(err.args).")
5149
else
5250
rethrow()
5351
end
5452
end
55-
try
56-
rhs = instantiate(eq.rhs)
57-
catch err
58-
if err isa Unitful.DimensionError
59-
@warn("In right-hand side of eq. #$eqnum: $(eq.rhs), $(err.x) and $(err.y) are not dimensionally compatible.")
60-
elseif err isa MethodError
61-
@warn("In right-hand side of eq. #$eqnum: $(err.f) doesn't accept $(err.args).")
62-
else
63-
rethrow()
64-
end
65-
end
66-
if (rhs !== nothing) && (lhs !== nothing)
67-
if !isequal(lhs, rhs)
68-
@warn("In eq. #$eqnum, left-side units ($lhs) and right-side units ($rhs) don't match.")
53+
side
54+
end
55+
56+
function _validate(terms::Vector,labels::Vector; info::String = "")
57+
equnits = safe_get_units.(terms,info.*labels)
58+
allthere = all(map(x->x!==nothing,equnits))
59+
allmatching = true
60+
if allthere
61+
for idx in 2:length(equnits)
62+
if !isequal(equnits[1],equnits[idx])
63+
allmatching = false
64+
@warn("$info: units $(equnits[1]) for $(labels[1]) and $(equnits[idx]) for $(labels[idx]) do not match.")
65+
end
6966
end
7067
end
71-
(rhs !== nothing) && (lhs !== nothing) && isequal(lhs, rhs)
68+
allthere && allmatching
69+
end
70+
71+
function validate(eq::ModelingToolkit.Equation; info::String = "")
72+
labels = ["left-hand side", "right-hand side"]
73+
terms = [eq.lhs,eq.rhs]
74+
_validate(terms,labels,info = info)
75+
end
76+
77+
function validate(eq::ModelingToolkit.Equation,noiseterm; info::String = "")
78+
labels = ["left-hand side", "right-hand side","noise term"]
79+
terms = [eq.lhs,eq.rhs,noiseterm]
80+
_validate(terms,labels,info = info)
81+
end
82+
83+
function validate(eq::ModelingToolkit.Equation,noisevec::Vector; info::String = "")
84+
labels = vcat(["left-hand side", "right-hand side"],"noise term #".* string.(1:length(noisevec)))
85+
terms = vcat([eq.lhs,eq.rhs],noisevec)
86+
_validate(terms,labels,info = info)
7287
end
7388

7489
function validate(eqs::Vector{ModelingToolkit.Equation})
75-
correct = [validate(eqs[idx],eqnum=idx) for idx in 1:length(eqs)]
76-
all(correct) || throw(ArgumentError("Invalid equations, see warnings for details."))
90+
all([validate(eqs[idx],info = "In eq. #$idx") for idx in 1:length(eqs)])
91+
end
92+
93+
function validate(eqs::Vector{ModelingToolkit.Equation},noise::Vector)
94+
all([validate(eqs[idx],noise[idx],info = "In eq. #$idx") for idx in 1:length(eqs)])
7795
end
7896

79-
validate(sys::AbstractODESystem) = validate(equations(sys))
97+
function validate(eqs::Vector{ModelingToolkit.Equation},noise::Matrix)
98+
all([validate(eqs[idx],noise[idx,:],info = "In eq. #$idx") for idx in 1:length(eqs)])
99+
end
100+
101+
"Returns true iff units of equations are valid."
102+
validate(eqs::Vector) = validate(convert.(ModelingToolkit.Equation,eqs))
103+
104+
"Throws error if units of equations are invalid."
105+
check_units(eqs...) = validate(eqs...) || throw(ArgumentError("Some equations had invalid units. See warnings for details."))

src/systems/discrete_system/discrete_system.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ struct DiscreteSystem <: AbstractSystem
5757
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, default_u0, default_p)
5858
check_variables(dvs,iv)
5959
check_parameters(ps,iv)
60+
check_units(discreteEqs)
6061
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, default_u0, default_p)
6162
end
6263
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ struct NonlinearSystem <: AbstractSystem
5454
type: type of the system
5555
"""
5656
connection_type::Any
57+
function NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, structure, connection_type)
58+
check_units(eqs)
59+
new(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, structure, connection_type)
60+
end
5761
end
5862

5963
function NonlinearSystem(eqs, states, ps;

src/systems/pde/pdesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ struct PDESystem <: ModelingToolkit.AbstractSystem
6161
defaults=Dict(),
6262
connection_type=nothing,
6363
)
64+
check_units(eqs)
6465
new(eqs, bcs, domain, indvars, depvars, ps, defaults, connection_type)
6566
end
6667
end

test/units.jl

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,41 +5,40 @@ MT = ModelingToolkit
55
@variables t [unit = u"ms"] E(t) [unit = u"kJ"] P(t) [unit = u"MW"]
66
D = Differential(t)
77

8-
@test MT.vartype(t) == u"ms"
9-
@test MT.vartype(E) == u"kJ"
10-
@test MT.vartype(τ) == u"ms"
11-
12-
13-
@test MT.instantiate(0.5) == 1.0
14-
@test MT.instantiate(t) == 1.0u"ms"
15-
@test MT.instantiate(P) == 1.0u"MW"
16-
@test MT.instantiate(τ) == 1.0u"ms"
17-
18-
@test MT.instantiate^-1) == 1/u"ms"
19-
@test MT.instantiate(D(E)) == 1.0u"MW"
20-
@test MT.instantiate(E/τ) == 1.0u"MW"
21-
@test MT.instantiate(2*P) == 1.0u"MW"
22-
@test MT.instantiate(t/τ) == 1.0
23-
@test MT.instantiate(P - E/τ)/1.0u"MW" == 1.0
24-
25-
@test MT.instantiate(1.0^(t/τ)) == 1.0
26-
@test MT.instantiate(exp(t/τ)) == 1.0
27-
@test MT.instantiate(sin(t/τ)) == 1.0
28-
@test MT.instantiate(sin(1u"rad")) == 1.0
29-
@test MT.instantiate(t^2) == 1.0u"ms"^2
8+
@test MT.get_units(t) == 1u"ms"
9+
@test MT.get_units(E) == 1u"kJ"
10+
@test MT.get_units(τ) == 1u"ms"
11+
12+
@test MT.get_units(0.5) == 1.0
13+
@test MT.get_units(t) == 1.0u"ms"
14+
@test MT.get_units(P) == 1.0u"MW"
15+
@test MT.get_units(τ) == 1.0u"ms"
16+
17+
@test MT.get_units^-1) == 1/u"ms"
18+
@test MT.get_units(D(E)) == 1.0u"MW"
19+
@test MT.get_units(E/τ) == 1.0u"MW"
20+
@test MT.get_units(2*P) == 1.0u"MW"
21+
@test MT.get_units(t/τ) == 1.0
22+
@test MT.get_units(P - E/τ)/1.0u"MW" == 1.0
23+
24+
@test MT.get_units(1.0^(t/τ)) == 1.0
25+
@test MT.get_units(exp(t/τ)) == 1.0
26+
@test MT.get_units(sin(t/τ)) == 1.0
27+
@test MT.get_units(sin(1u"rad")) == 1.0
28+
@test MT.get_units(t^2) == 1.0u"ms"^2
3029

3130
@test !MT.validate(E^1.5 ~ E^(t/τ))
3231
@test MT.validate(E^(t/τ) ~ E^(t/τ))
3332

3433
eqs = [D(E) ~ P - E/τ
3534
0.0u"MW" ~ P]
36-
@test MT.instantiate(eqs[1].lhs) == 1.0u"MW"
37-
@test MT.instantiate(eqs[1].rhs) == 1.0u"MW"
35+
@test MT.get_units(eqs[1].lhs) == 1.0u"MW"
36+
@test MT.get_units(eqs[1].rhs) == 1.0u"MW"
3837
@test MT.validate(eqs[1])
3938
@test MT.validate(eqs[2])
39+
@test MT.validate(eqs)
4040
sys = ODESystem(eqs)
4141
sys = ODESystem(eqs, t, [P, E], [τ])
42-
@test MT.validate(sys)
4342

4443
@test !MT.validate(D(D(E)) ~ P)
4544
@test !MT.validate(0 ~ P + E*τ)
@@ -112,4 +111,31 @@ eqs = [D(y[1]) ~ -k[1]*y[1] + k[3]*y[2]*y[3],
112111
D(y[2]) ~ k[1]*y[1] - k[3]*y[2]*y[3] - k[2]*y[2]^2,
113112
0 ~ y[1] + y[2] + y[3] - 1]
114113

115-
sys = ODESystem(eqs,t,y,k)
114+
sys = ODESystem(eqs,t,y,k)
115+
116+
# Nonlinear system
117+
@parameters a [unit = u"kg"^-1]
118+
@variables x [unit = u"kg"]
119+
eqs = [
120+
0 ~ a*x
121+
]
122+
nls = NonlinearSystem(eqs, [x], [a])
123+
124+
# SDE test w/ noise vector
125+
@parameters τ [unit = u"ms"] Q [unit = u"MW"]
126+
@variables t [unit = u"ms"] E(t) [unit = u"kJ"] P(t) [unit = u"MW"]
127+
D = Differential(t)
128+
eqs = [D(E) ~ P - E/τ
129+
P ~ Q]
130+
131+
noiseeqs = [0.1u"MW",
132+
0.1u"MW"]
133+
sys = SDESystem(eqs,noiseeqs,t,[P,E],[τ,Q])
134+
# With noise matrix
135+
noiseeqs = [0.1u"MW" 0.1u"MW"
136+
0.1u"MW" 0.1u"MW"]
137+
sys = SDESystem(eqs,noiseeqs,t,[P,E],[τ,Q])
138+
139+
noiseeqs = [0.1u"MW" 0.1u"MW"
140+
0.1u"MW" 0.1u"s"]
141+
@test !MT.validate(eqs,noiseeqs)

0 commit comments

Comments
 (0)