Skip to content

Commit e037dfe

Browse files
Merge pull request #2898 from contradict/variable_value_units
Make default value units consistent
2 parents 3f7ad46 + 49cf9ef commit e037dfe

File tree

4 files changed

+145
-33
lines changed

4 files changed

+145
-33
lines changed

src/systems/model_parsing.jl

Lines changed: 72 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function _model_macro(mod, name, expr, isconnector)
5151
c_evts = []
5252
d_evts = []
5353
kwargs = OrderedCollections.OrderedSet()
54-
where_types = Expr[]
54+
where_types = Union{Symbol, Expr}[]
5555

5656
push!(exprs.args, :(variables = []))
5757
push!(exprs.args, :(parameters = []))
@@ -143,9 +143,15 @@ end
143143
pop_structure_dict!(dict, key) = length(dict[key]) == 0 && pop!(dict, key)
144144

145145
function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
146-
varclass, where_types)
146+
varclass, where_types, meta)
147147
if indices isa Nothing
148-
push!(kwargs, Expr(:kw, Expr(:(::), a, Union{Nothing, type}), nothing))
148+
if !isnothing(meta) && haskey(meta, VariableUnit)
149+
uvar = gensym()
150+
push!(where_types, uvar)
151+
push!(kwargs, Expr(:kw, :($a::Union{Nothing, $uvar}), nothing))
152+
else
153+
push!(kwargs, Expr(:kw, :($a::Union{Nothing, $type}), nothing))
154+
end
149155
dict[:kwargs][getname(var)] = Dict(:value => def, :type => type)
150156
else
151157
vartype = gensym(:T)
@@ -154,7 +160,11 @@ function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
154160
Expr(:(::), a,
155161
Expr(:curly, :Union, :Nothing, Expr(:curly, :AbstractArray, vartype))),
156162
nothing))
157-
push!(where_types, :($vartype <: $type))
163+
if !isnothing(meta) && haskey(meta, VariableUnit)
164+
push!(where_types, vartype)
165+
else
166+
push!(where_types, :($vartype <: $type))
167+
end
158168
dict[:kwargs][getname(var)] = Dict(:value => def, :type => AbstractArray{type})
159169
end
160170
if dict[varclass] isa Vector
@@ -166,7 +176,7 @@ end
166176

167177
function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
168178
def = nothing, indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing,
169-
type::Type = Real)
179+
type::Type = Real, meta = Dict{DataType, Expr}())
170180
metatypes = [(:connection_type, VariableConnectType),
171181
(:description, VariableDescription),
172182
(:unit, VariableUnit),
@@ -186,29 +196,31 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
186196
a::Symbol => begin
187197
var = generate_var!(dict, a, varclass; indices, type)
188198
update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
189-
varclass, where_types)
199+
varclass, where_types, meta)
190200
return var, def, Dict()
191201
end
192202
Expr(:(::), a, type) => begin
193203
type = getfield(mod, type)
194-
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types; def, type)
204+
parse_variable_def!(
205+
dict, mod, a, varclass, kwargs, where_types; def, type, meta)
195206
end
196207
Expr(:(::), Expr(:call, a, b), type) => begin
197208
type = getfield(mod, type)
198209
def = _type_check!(def, a, type, varclass)
199-
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types; def, type)
210+
parse_variable_def!(
211+
dict, mod, a, varclass, kwargs, where_types; def, type, meta)
200212
end
201213
Expr(:call, a, b) => begin
202214
var = generate_var!(dict, a, b, varclass, mod; indices, type)
203215
update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
204-
varclass, where_types)
216+
varclass, where_types, meta)
205217
return var, def, Dict()
206218
end
207219
Expr(:(=), a, b) => begin
208220
Base.remove_linenums!(b)
209221
def, meta = parse_default(mod, b)
210222
var, def, _ = parse_variable_def!(
211-
dict, mod, a, varclass, kwargs, where_types; def, type)
223+
dict, mod, a, varclass, kwargs, where_types; def, type, meta)
212224
if dict[varclass] isa Vector
213225
dict[varclass][1][getname(var)][:default] = def
214226
else
@@ -231,9 +243,9 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
231243
return var, def, Dict()
232244
end
233245
Expr(:tuple, a, b) => begin
234-
var, def, _ = parse_variable_def!(
235-
dict, mod, a, varclass, kwargs, where_types; type)
236246
meta = parse_metadata(mod, b)
247+
var, def, _ = parse_variable_def!(
248+
dict, mod, a, varclass, kwargs, where_types; type, meta)
237249
if meta !== nothing
238250
for (type, key) in metatypes
239251
if (mt = get(meta, key, nothing)) !== nothing
@@ -253,7 +265,7 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
253265
Expr(:ref, a, b...) => begin
254266
indices = map(i -> UnitRange(i.args[2], i.args[end]), b)
255267
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types;
256-
def, indices, type)
268+
def, indices, type, meta)
257269
end
258270
_ => error("$arg cannot be parsed")
259271
end
@@ -611,16 +623,58 @@ function parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs, where_
611623
push!(exprs, ex)
612624
end
613625

626+
function convert_units(varunits::DynamicQuantities.Quantity, value)
627+
DynamicQuantities.ustrip(DynamicQuantities.uconvert(
628+
DynamicQuantities.SymbolicUnits.as_quantity(varunits), value))
629+
end
630+
631+
function convert_units(
632+
varunits::DynamicQuantities.Quantity, value::AbstractArray{T}) where {T}
633+
DynamicQuantities.ustrip.(DynamicQuantities.uconvert.(
634+
DynamicQuantities.SymbolicUnits.as_quantity(varunits), value))
635+
end
636+
637+
function convert_units(varunits::Unitful.FreeUnits, value)
638+
Unitful.ustrip(varunits, value)
639+
end
640+
641+
function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T}
642+
Unitful.ustrip.(varunits, value)
643+
end
644+
614645
function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types)
615646
vv, def, metadata_with_exprs = parse_variable_def!(
616647
dict, mod, arg, varclass, kwargs, where_types)
617648
name = getname(vv)
618649

619-
varexpr = quote
620-
$name = if $name === nothing
621-
$setdefault($vv, $def)
622-
else
623-
$setdefault($vv, $name)
650+
varexpr = if haskey(metadata_with_exprs, VariableUnit)
651+
unit = metadata_with_exprs[VariableUnit]
652+
quote
653+
$name = if $name === nothing
654+
$setdefault($vv, $def)
655+
else
656+
try
657+
$setdefault($vv, $convert_units($unit, $name))
658+
catch e
659+
if isa(e, $(DynamicQuantities.DimensionError)) ||
660+
isa(e, $(Unitful.DimensionError))
661+
error("Unable to convert units for \'" * string(:($$vv)) * "\'")
662+
elseif isa(e, MethodError)
663+
error("No or invalid units provided for \'" * string(:($$vv)) *
664+
"\'")
665+
else
666+
rethrow(e)
667+
end
668+
end
669+
end
670+
end
671+
else
672+
quote
673+
$name = if $name === nothing
674+
$setdefault($vv, $def)
675+
else
676+
$setdefault($vv, $name)
677+
end
624678
end
625679
end
626680

test/dq_units.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,31 @@ maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
157157
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
158158
maj2 = MassActionJump(γ, [S => 1], [S => -1])
159159
@named js4 = JumpSystem([maj1, maj2], ModelingToolkit.t_nounits, [S], [β, γ])
160+
161+
@mtkmodel ParamTest begin
162+
@parameters begin
163+
a, [unit = u"m"]
164+
end
165+
@variables begin
166+
b(t), [unit = u"kg"]
167+
end
168+
end
169+
170+
@named sys = ParamTest()
171+
172+
@named sys = ParamTest(a = 3.0u"cm")
173+
@test ModelingToolkit.getdefault(sys.a) 0.03
174+
175+
@test_throws ErrorException ParamTest(; name = :t, a = 1.0)
176+
@test_throws ErrorException ParamTest(; name = :t, a = 1.0u"s")
177+
178+
@mtkmodel ArrayParamTest begin
179+
@parameters begin
180+
a[1:2], [unit = u"m"]
181+
end
182+
end
183+
184+
@named sys = ArrayParamTest()
185+
186+
@named sys = ArrayParamTest(a = [1.0, 3.0]u"cm")
187+
@test ModelingToolkit.getdefault(sys.a) [0.01, 0.03]

test/model_parsing.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using ModelingToolkit, Test
22
using ModelingToolkit: get_connector_type, get_defaults, get_gui_metadata,
33
get_systems, get_ps, getdefault, getname, readable_code,
4-
scalarize, symtype, VariableDescription, RegularConnector
4+
scalarize, symtype, VariableDescription, RegularConnector,
5+
get_unit
56
using URIs: URI
67
using Distributions
78
using DynamicQuantities, OrdinaryDiffEq
@@ -53,8 +54,9 @@ end
5354
end
5455
end
5556

56-
@named p = Pin(; v = π)
57-
@test getdefault(p.v) == π
57+
@named p = Pin(; v = π * u"V")
58+
59+
@test getdefault(p.v) π
5860
@test Pin.isconnector == true
5961

6062
@mtkmodel OnePort begin
@@ -76,7 +78,6 @@ end
7678

7779
@test OnePort.isconnector == false
7880

79-
resistor_log = "$(@__DIR__)/logo/resistor.svg"
8081
@mtkmodel Resistor begin
8182
@extend v, i = oneport = OnePort()
8283
@parameters begin
@@ -105,14 +106,14 @@ end
105106
@parameters begin
106107
C, [unit = u"F"]
107108
end
108-
@extend OnePort(; v = 0.0)
109+
@extend OnePort(; v = 0.0u"V")
109110
@icon "https://upload.wikimedia.org/wikipedia/commons/7/78/Capacitor_symbol.svg"
110111
@equations begin
111112
D(v) ~ i / C
112113
end
113114
end
114115

115-
@named capacitor = Capacitor(C = 10, v = 10.0)
116+
@named capacitor = Capacitor(C = 10u"F", v = 10.0u"V")
116117
@test getdefault(capacitor.v) == 10.0
117118

118119
@mtkmodel Voltage begin
@@ -127,9 +128,9 @@ end
127128

128129
@mtkmodel RC begin
129130
@structural_parameters begin
130-
R_val = 10
131-
C_val = 10
132-
k_val = 10
131+
R_val = 10u"Ω"
132+
C_val = 10u"F"
133+
k_val = 10u"V"
133134
end
134135
@components begin
135136
resistor = Resistor(; R = R_val)
@@ -147,9 +148,9 @@ end
147148
end
148149
end
149150

150-
C_val = 20
151-
R_val = 20
152-
res__R = 100
151+
C_val = 20u"F"
152+
R_val = 20u"Ω"
153+
res__R = 100u"Ω"
153154
@mtkbuild rc = RC(; C_val, R_val, resistor.R = res__R)
154155
prob = ODEProblem(rc, [], (0, 1e9))
155156
sol = solve(prob, Rodas5P())
@@ -160,11 +161,12 @@ resistor = getproperty(rc, :resistor; namespace = false)
160161
@test getname(rc.resistor.R) === getname(resistor.R)
161162
@test getname(rc.resistor.v) === getname(resistor.v)
162163
# Test that `resistor.R` overrides `R_val` in the argument.
163-
@test getdefault(rc.resistor.R) == res__R != R_val
164+
@test getdefault(rc.resistor.R) * get_unit(rc.resistor.R) == res__R != R_val
164165
# Test that `C_val` passed via argument is set as default of C.
165-
@test getdefault(rc.capacitor.C) == C_val
166+
@test getdefault(rc.capacitor.C) * get_unit(rc.capacitor.C) == C_val
166167
# Test that `k`'s default value is unchanged.
167-
@test getdefault(rc.constant.k) == RC.structure[:kwargs][:k_val][:value]
168+
@test getdefault(rc.constant.k) * get_unit(rc.constant.k) ==
169+
eval(RC.structure[:kwargs][:k_val][:value])
168170
@test getdefault(rc.capacitor.v) == 0.0
169171

170172
@test get_gui_metadata(rc.resistor).layout == Resistor.structure[:icon] ==

test/units.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,31 @@ maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
192192
maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
193193
maj2 = MassActionJump(γ, [S => 1], [S => -1])
194194
@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ])
195+
196+
@mtkmodel ParamTest begin
197+
@parameters begin
198+
a, [unit = u"m"]
199+
end
200+
@variables begin
201+
b(t), [unit = u"kg"]
202+
end
203+
end
204+
205+
@named sys = ParamTest()
206+
207+
@named sys = ParamTest(a = 3.0u"cm")
208+
@test ModelingToolkit.getdefault(sys.a) 0.03
209+
210+
@test_throws ErrorException ParamTest(; name = :t, a = 1.0)
211+
@test_throws ErrorException ParamTest(; name = :t, a = 1.0u"s")
212+
213+
@mtkmodel ArrayParamTest begin
214+
@parameters begin
215+
a[1:2], [unit = u"m"]
216+
end
217+
end
218+
219+
@named sys = ArrayParamTest()
220+
221+
@named sys = ArrayParamTest(a = [1.0, 3.0]u"cm")
222+
@test ModelingToolkit.getdefault(sys.a) [0.01, 0.03]

0 commit comments

Comments
 (0)