Skip to content

Commit 8436903

Browse files
Merge pull request #2521 from ven-k/vkb/array-type
fix: type of array in the mtkmodel's `f`
2 parents 5c03ad3 + 5ee52dc commit 8436903

File tree

3 files changed

+129
-73
lines changed

3 files changed

+129
-73
lines changed

docs/src/basics/MTKModel_Connector.md

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ julia> @mtkbuild model_c2 = ModelC(; p1 = 2.0)
140140
- Whenever components are created with `@named` macro, these can be accessed with `.` operator as `subcomponent_name.argname`
141141
- In the above example, as `k` of `model_a` isn't listed while defining the sub-component in `ModelC`, its default value can't be modified by users. While `k_array` can be set as:
142142

143-
```julia
143+
```@example mtkmodel-example
144144
using ModelingToolkit: getdefault
145145
146146
@mtkbuild model_c3 = ModelC(; model_a.k_array = [1.0, 2.0])
@@ -149,13 +149,6 @@ getdefault(model_c3.model_a.k_array[1])
149149
# 1.0
150150
getdefault(model_c3.model_a.k_array[2])
151151
# 2.0
152-
153-
@mtkbuild model_c4 = ModelC(model_a.k_array = 3.0)
154-
155-
getdefault(model_c4.model_a.k_array[1])
156-
# 3.0
157-
getdefault(model_c4.model_a.k_array[2])
158-
# 3.0
159152
```
160153

161154
#### `@equations` begin block
@@ -242,15 +235,16 @@ For example, the structure of `ModelC` is:
242235

243236
```julia
244237
julia> ModelC.structure
245-
Dict{Symbol, Any} with 7 entries:
246-
:components => [[:model_a, :ModelA]]
247-
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var), :v_array=>Dict(:size=>(2, 3)))
238+
Dict{Symbol, Any} with 9 entries:
239+
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA]]
240+
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_array=>Dict(:type=>Real, :size=>(2, 3)))
248241
:icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png")
249242
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :v=>Dict{Symbol, Union{Nothing, Symbol}}(:value=>:v_var, :type=>Real), :v_array=>Dict(:value=>nothing, :type=>Real), :p1=>Dict(:value=>nothing))
250243
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin))
251244
:independent_variable => t
245+
:constants => Dict{Symbol, Dict}(:c=>Dict(:value=>1))
252246
:extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB]
253-
:equations => ["model_a.k ~ f(v)"]
247+
:equations => Any["model_a.k ~ f(v)"]
254248
```
255249

256250
### Using conditional statements
@@ -322,12 +316,12 @@ The conditional parts are reflected in the `structure`. For `BranchOutsideTheBlo
322316

323317
```julia
324318
julia> BranchOutsideTheBlock.structure
325-
Dict{Symbol, Any} with 5 entries:
326-
:components => Any[(:if, :flag, [[:sys1, :C]], Any[])]
319+
Dict{Symbol, Any} with 6 entries:
320+
:components => Any[(:if, :flag, Vector{Union{Expr, Symbol}}[[:sys1, :C]], Any[])]
327321
:kwargs => Dict{Symbol, Dict}(:flag=>Dict{Symbol, Bool}(:value=>1))
328322
:structural_parameters => Dict{Symbol, Dict}(:flag=>Dict{Symbol, Bool}(:value=>1))
329323
:independent_variable => t
330-
:parameters => Dict{Symbol, Dict{Symbol, Any}}(:a1=>Dict(:condition=>(:if, :flag, Dict{Symbol, Any}(:kwargs => Dict{Any, Any}(:a1 => nothing), :parameters => Any[Dict{Symbol, Dict{Symbol, Any}}(:a1 => Dict())]), Dict{Symbol, Any}(:kwargs => Dict{Any, Any}(:a2 => nothing), :parameters => Any[Dict{Symbol, Dict{Symbol, Any}}(:a2 => Dict())]))
324+
:parameters => Dict{Symbol, Dict{Symbol, Any}}(:a2 => Dict(:type => AbstractArray{Real}, :condition => (:if, :flag, Dict{Symbol, Any}(:kwargs => Dict{Any, Any}(:a1 => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real)), :parameters => Any[Dict{Symbol, Dict{Symbol, Any}}(:a1 => Dict(:type => AbstractArray{Real}))]), Dict{Symbol, Any}(:variables => Any[Dict{Symbol, Dict{Symbol, Any}}()], :kwargs => Dict{Any, Any}(:a2 => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real)), :parameters => Any[Dict{Symbol, Dict{Symbol, Any}}(:a2 => Dict(:type => AbstractArray{Real}))]))), :a1 => Dict(:type => AbstractArray{Real}, :condition => (:if, :flag, Dict{Symbol, Any}(:kwargs => Dict{Any, Any}(:a1 => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real)), :parameters => Any[Dict{Symbol, Dict{Symbol, Any}}(:a1 => Dict(:type => AbstractArray{Real}))]), Dict{Symbol, Any}(:variables => Any[Dict{Symbol, Dict{Symbol, Any}}()], :kwargs => Dict{Any, Any}(:a2 => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real)), :parameters => Any[Dict{Symbol, Dict{Symbol, Any}}(:a2 => Dict(:type => AbstractArray{Real}))]))))
331325
:equations => Any[(:if, :flag, ["a1 ~ 0"], ["a2 ~ 0"])]
332326
```
333327

src/systems/model_parsing.jl

Lines changed: 84 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ function _model_macro(mod, name, expr, isconnector)
4545
icon = Ref{Union{String, URI}}()
4646
ps, sps, vs, = [], [], []
4747
kwargs = Set()
48+
where_types = Expr[]
4849

4950
push!(exprs.args, :(variables = []))
5051
push!(exprs.args, :(parameters = []))
@@ -55,25 +56,28 @@ function _model_macro(mod, name, expr, isconnector)
5556
for arg in expr.args
5657
if arg.head == :macrocall
5758
parse_model!(exprs.args, comps, ext, eqs, icon, vs, ps,
58-
sps, dict, mod, arg, kwargs)
59+
sps, dict, mod, arg, kwargs, where_types)
5960
elseif arg.head == :block
6061
push!(exprs.args, arg)
6162
elseif arg.head == :if
6263
MLStyle.@match arg begin
6364
Expr(:if, condition, x) => begin
6465
parse_conditional_model_statements(comps, dict, eqs, exprs, kwargs,
65-
mod, ps, vs, parse_top_level_branch(condition, x.args)...)
66+
mod, ps, vs, where_types,
67+
parse_top_level_branch(condition, x.args)...)
6668
end
6769
Expr(:if, condition, x, y) => begin
6870
parse_conditional_model_statements(comps, dict, eqs, exprs, kwargs,
69-
mod, ps, vs, parse_top_level_branch(condition, x.args, y)...)
71+
mod, ps, vs, where_types,
72+
parse_top_level_branch(condition, x.args, y)...)
7073
end
7174
_ => error("Got an invalid argument: $arg")
7275
end
7376
elseif isconnector
7477
# Connectors can have variables listed without `@variables` prefix or
7578
# begin block.
76-
parse_variable_arg!(exprs.args, vs, dict, mod, arg, :variables, kwargs)
79+
parse_variable_arg!(
80+
exprs.args, vs, dict, mod, arg, :variables, kwargs, where_types)
7781
else
7882
error("$arg is not valid syntax. Expected a macro call.")
7983
end
@@ -104,11 +108,40 @@ function _model_macro(mod, name, expr, isconnector)
104108
isconnector && push!(exprs.args,
105109
:($Setfield.@set!(var"#___sys___".connector_type=$connector_type(var"#___sys___"))))
106110

107-
f = :($(Symbol(:__, name, :__))(; name, $(kwargs...)) = $exprs)
111+
f = if length(where_types) == 0
112+
:($(Symbol(:__, name, :__))(; name, $(kwargs...)) = $exprs)
113+
else
114+
f_with_where = Expr(:where)
115+
push!(f_with_where.args,
116+
:($(Symbol(:__, name, :__))(; name, $(kwargs...))), where_types...)
117+
:($f_with_where = $exprs)
118+
end
108119
:($name = $Model($f, $dict, $isconnector))
109120
end
110121

111-
function parse_variable_def!(dict, mod, arg, varclass, kwargs;
122+
function update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
123+
varclass, where_types)
124+
if indices isa Nothing
125+
push!(kwargs, Expr(:kw, Expr(:(::), a, Union{Nothing, type}), nothing))
126+
dict[:kwargs][getname(var)] = Dict(:value => def, :type => type)
127+
else
128+
vartype = gensym(:T)
129+
push!(kwargs,
130+
Expr(:kw,
131+
Expr(:(::), a,
132+
Expr(:curly, :Union, :Nothing, Expr(:curly, :AbstractArray, vartype))),
133+
nothing))
134+
push!(where_types, :($vartype <: $type))
135+
dict[:kwargs][getname(var)] = Dict(:value => def, :type => AbstractArray{type})
136+
end
137+
if dict[varclass] isa Vector
138+
dict[varclass][1][getname(var)][:type] = AbstractArray{type}
139+
else
140+
dict[varclass][getname(var)][:type] = type
141+
end
142+
end
143+
144+
function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
112145
def = nothing, indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing,
113146
type::Type = Real)
114147
metatypes = [(:connection_type, VariableConnectType),
@@ -128,40 +161,31 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs;
128161
arg isa LineNumberNode && return
129162
MLStyle.@match arg begin
130163
a::Symbol => begin
131-
if type isa Nothing
132-
push!(kwargs, Expr(:kw, a, nothing))
133-
else
134-
push!(kwargs, Expr(:kw, Expr(:(::), a, Union{Nothing, type}), nothing))
135-
end
136164
var = generate_var!(dict, a, varclass; indices, type)
137-
dict[:kwargs][getname(var)] = Dict(:value => def, :type => type)
165+
update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
166+
varclass, where_types)
138167
(var, def)
139168
end
140169
Expr(:(::), a, type) => begin
141-
type = Core.eval(mod, type)
142-
_type_check!(a, type)
143-
parse_variable_def!(dict, mod, a, varclass, kwargs; def, type)
170+
type = getfield(mod, type)
171+
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types; def, type)
144172
end
145173
Expr(:(::), Expr(:call, a, b), type) => begin
146-
type = Core.eval(mod, type)
147-
def = _type_check!(def, a, type)
148-
parse_variable_def!(dict, mod, a, varclass, kwargs; def, type)
174+
type = getfield(mod, type)
175+
def = _type_check!(def, a, type, varclass)
176+
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types; def, type)
149177
end
150178
Expr(:call, a, b) => begin
151-
if type isa Nothing
152-
push!(kwargs, Expr(:kw, a, nothing))
153-
else
154-
push!(kwargs, Expr(:kw, Expr(:(::), a, Union{Nothing, type}), nothing))
155-
end
156179
var = generate_var!(dict, a, b, varclass; indices, type)
157-
type !== nothing && (dict[varclass][getname(var)][:type] = type)
158-
dict[:kwargs][getname(var)] = Dict(:value => def, :type => type)
180+
update_kwargs_and_metadata!(dict, kwargs, a, def, indices, type, var,
181+
varclass, where_types)
159182
(var, def)
160183
end
161184
Expr(:(=), a, b) => begin
162185
Base.remove_linenums!(b)
163186
def, meta = parse_default(mod, b)
164-
var, def = parse_variable_def!(dict, mod, a, varclass, kwargs; def, type)
187+
var, def = parse_variable_def!(
188+
dict, mod, a, varclass, kwargs, where_types; def, type)
165189
if dict[varclass] isa Vector
166190
dict[varclass][1][getname(var)][:default] = def
167191
else
@@ -183,7 +207,8 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs;
183207
(var, def)
184208
end
185209
Expr(:tuple, a, b) => begin
186-
var, def = parse_variable_def!(dict, mod, a, varclass, kwargs; type)
210+
var, def = parse_variable_def!(
211+
dict, mod, a, varclass, kwargs, where_types; type)
187212
meta = parse_metadata(mod, b)
188213
if meta !== nothing
189214
for (type, key) in metatypes
@@ -202,7 +227,7 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs;
202227
end
203228
Expr(:ref, a, b...) => begin
204229
indices = map(i -> UnitRange(i.args[2], i.args[end]), b)
205-
parse_variable_def!(dict, mod, a, varclass, kwargs;
230+
parse_variable_def!(dict, mod, a, varclass, kwargs, where_types;
206231
def, indices, type)
207232
end
208233
_ => error("$arg cannot be parsed")
@@ -307,17 +332,17 @@ function get_var(mod::Module, b)
307332
end
308333

309334
function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps,
310-
dict, mod, arg, kwargs)
335+
dict, mod, arg, kwargs, where_types)
311336
mname = arg.args[1]
312337
body = arg.args[end]
313338
if mname == Symbol("@components")
314339
parse_components!(exprs, comps, dict, body, kwargs)
315340
elseif mname == Symbol("@extend")
316341
parse_extend!(exprs, ext, dict, mod, body, kwargs)
317342
elseif mname == Symbol("@variables")
318-
parse_variables!(exprs, vs, dict, mod, body, :variables, kwargs)
343+
parse_variables!(exprs, vs, dict, mod, body, :variables, kwargs, where_types)
319344
elseif mname == Symbol("@parameters")
320-
parse_variables!(exprs, ps, dict, mod, body, :parameters, kwargs)
345+
parse_variables!(exprs, ps, dict, mod, body, :parameters, kwargs, where_types)
321346
elseif mname == Symbol("@structural_parameters")
322347
parse_structural_parameters!(exprs, sps, dict, mod, body, kwargs)
323348
elseif mname == Symbol("@equations")
@@ -336,7 +361,7 @@ function parse_structural_parameters!(exprs, sps, dict, mod, body, kwargs)
336361
MLStyle.@match arg begin
337362
Expr(:(=), Expr(:(::), a, type), b) => begin
338363
type = Core.eval(mod, type)
339-
b = _type_check!(Core.eval(mod, b), a, type)
364+
b = _type_check!(Core.eval(mod, b), a, type, :structural_parameters)
340365
push!(sps, a)
341366
push!(kwargs, Expr(:kw, Expr(:(::), a, type), b))
342367
dict[:structural_parameters][a] = dict[:kwargs][a] = Dict(
@@ -454,25 +479,27 @@ function parse_extend!(exprs, ext, dict, mod, body, kwargs)
454479
return nothing
455480
end
456481

457-
function parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs)
458-
name, ex = parse_variable_arg(dict, mod, arg, varclass, kwargs)
482+
function parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs, where_types)
483+
name, ex = parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types)
459484
push!(vs, name)
460485
push!(exprs, ex)
461486
end
462487

463-
function parse_variable_arg(dict, mod, arg, varclass, kwargs)
464-
vv, def = parse_variable_def!(dict, mod, arg, varclass, kwargs)
488+
function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types)
489+
vv, def = parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types)
465490
name = getname(vv)
466491
return vv isa Num ? name : :($name...),
467492
:($name = $name === nothing ? $setdefault($vv, $def) : $setdefault($vv, $name))
468493
end
469494

470-
function handle_conditional_vars!(arg, conditional_branch, mod, varclass, kwargs)
495+
function handle_conditional_vars!(
496+
arg, conditional_branch, mod, varclass, kwargs, where_types)
471497
conditional_dict = Dict(:kwargs => Dict(),
472498
:parameters => Any[Dict{Symbol, Dict{Symbol, Any}}()],
473499
:variables => Any[Dict{Symbol, Dict{Symbol, Any}}()])
474500
for _arg in arg.args
475-
name, ex = parse_variable_arg(conditional_dict, mod, _arg, varclass, kwargs)
501+
name, ex = parse_variable_arg(
502+
conditional_dict, mod, _arg, varclass, kwargs, where_types)
476503
push!(conditional_branch.args, ex)
477504
push!(conditional_branch.args, :(push!($varclass, $name)))
478505
end
@@ -530,7 +557,7 @@ function push_conditional_dict!(dict, condition, conditional_dict,
530557
end
531558
end
532559

533-
function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
560+
function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs, where_types)
534561
expr = Expr(:block)
535562
push!(exprs, expr)
536563
for arg in body.args
@@ -542,7 +569,8 @@ function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
542569
conditional_expr.args[2],
543570
mod,
544571
varclass,
545-
kwargs)
572+
kwargs,
573+
where_types)
546574
push!(expr.args, conditional_expr)
547575
push_conditional_dict!(dict, condition, conditional_dict, nothing, varclass)
548576
end
@@ -552,12 +580,13 @@ function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
552580
conditional_expr.args[2],
553581
mod,
554582
varclass,
555-
kwargs)
583+
kwargs,
584+
where_types)
556585
conditional_y_expr, conditional_y_tuple = handle_y_vars(y,
557586
conditional_dict,
558587
mod,
559588
varclass,
560-
kwargs)
589+
kwargs, where_types)
561590
push!(conditional_expr.args, conditional_y_expr)
562591
push!(expr.args, conditional_expr)
563592
push_conditional_dict!(dict,
@@ -566,25 +595,28 @@ function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
566595
conditional_y_tuple,
567596
varclass)
568597
end
569-
_ => parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs)
598+
_ => parse_variable_arg!(
599+
exprs, vs, dict, mod, arg, varclass, kwargs, where_types)
570600
end
571601
end
572602
end
573603

574-
function handle_y_vars(y, dict, mod, varclass, kwargs)
604+
function handle_y_vars(y, dict, mod, varclass, kwargs, where_types)
575605
conditional_dict = if Meta.isexpr(y, :elseif)
576606
conditional_y_expr = Expr(:elseif, y.args[1], Expr(:block))
577607
conditional_dict = handle_conditional_vars!(y.args[2],
578608
conditional_y_expr.args[2],
579609
mod,
580610
varclass,
581-
kwargs)
582-
_y_expr, _conditional_dict = handle_y_vars(y.args[end], dict, mod, varclass, kwargs)
611+
kwargs,
612+
where_types)
613+
_y_expr, _conditional_dict = handle_y_vars(
614+
y.args[end], dict, mod, varclass, kwargs, where_types)
583615
push!(conditional_y_expr.args, _y_expr)
584616
(:elseif, y.args[1], conditional_dict, _conditional_dict)
585617
else
586618
conditional_y_expr = Expr(:block)
587-
handle_conditional_vars!(y, conditional_y_expr, mod, varclass, kwargs)
619+
handle_conditional_vars!(y, conditional_y_expr, mod, varclass, kwargs, where_types)
588620
end
589621
conditional_y_expr, conditional_dict
590622
end
@@ -865,18 +897,18 @@ function parse_top_level_branch(condition, x, y = nothing, branch = :if)
865897
end
866898

867899
function parse_conditional_model_statements(comps, dict, eqs, exprs, kwargs, mod,
868-
ps, vs, component_blk, equations_blk, parameter_blk, variable_blk)
900+
ps, vs, where_types, component_blk, equations_blk, parameter_blk, variable_blk)
869901
parameter_blk !== nothing &&
870902
parse_variables!(
871903
exprs.args, ps, dict, mod, :(begin
872904
$parameter_blk
873-
end), :parameters, kwargs)
905+
end), :parameters, kwargs, where_types)
874906

875907
variable_blk !== nothing &&
876908
parse_variables!(
877909
exprs.args, vs, dict, mod, :(begin
878910
$variable_blk
879-
end), :variables, kwargs)
911+
end), :variables, kwargs, where_types)
880912

881913
component_blk !== nothing &&
882914
parse_components!(exprs.args,
@@ -890,8 +922,7 @@ function parse_conditional_model_statements(comps, dict, eqs, exprs, kwargs, mod
890922
end))
891923
end
892924

893-
_type_check!(a, type) = return
894-
function _type_check!(val, a, type)
925+
function _type_check!(val, a, type, varclass)
895926
if val isa type
896927
return val
897928
else
@@ -900,7 +931,7 @@ function _type_check!(val, a, type)
900931
catch
901932
(e)
902933
throw(TypeError(Symbol("`@mtkmodel`"),
903-
"`@structural_parameters`, while assigning to `$a`", type, typeof(val)))
934+
"`$varclass`, while assigning to `$a`", type, typeof(val)))
904935
end
905936
end
906937
end

0 commit comments

Comments
 (0)