Skip to content

Commit a64aad8

Browse files
Merge pull request #2368 from ven-k/vkb/component-arrays
Add support for component array in `@mtkmodel`
2 parents 2a28c4d + e324633 commit a64aad8

File tree

4 files changed

+174
-54
lines changed

4 files changed

+174
-54
lines changed

docs/src/basics/MTKModel_Connector.md

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ end
5858
end
5959
@structural_parameters begin
6060
f = sin
61+
N = 2
6162
end
6263
begin
6364
v_var = 1.0
@@ -69,6 +70,11 @@ end
6970
@extend ModelB(; p1)
7071
@components begin
7172
model_a = ModelA(; k_array)
73+
model_array_a = [ModelA(; k = i) for i in 1:N]
74+
model_array_b = for i in 1:N
75+
k = i^2
76+
ModelA(; k)
77+
end
7278
end
7379
@equations begin
7480
model_a.k ~ f(v)
@@ -146,6 +152,7 @@ julia> @mtkbuild model_c2 = ModelC(; p1 = 2.0)
146152
#### `@components` begin block
147153

148154
- Declare the subcomponents within `@components` begin block.
155+
- Array of components can be declared with a for loop or a list comprehension.
149156
- The arguments in these subcomponents are promoted as keyword arguments as `subcomponent_name__argname` with `nothing` as default value.
150157
- Whenever components are created with `@named` macro, these can be accessed with `.` operator as `subcomponent_name.argname`
151158
- In the above example, as `k` of `model_a` isn't listed while defining the sub-component in `ModelC`, its default value can't be modified by users. While `k_array` can be set as:
@@ -247,14 +254,13 @@ For example, the structure of `ModelC` is:
247254
```julia
248255
julia> ModelC.structure
249256
Dict{Symbol, Any} with 9 entries:
250-
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA]]
257+
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA], Union{Expr, Symbol}[:model_array_a, :ModelA, :(1:N)], Union{Expr, Symbol}[:model_array_b, :ModelA, :(1:N)]]
251258
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_array=>Dict(:type=>Real, :size=>(2, 3)))
252259
:icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png")
253-
:constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant."))
254-
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :v=>Dict{Symbol, Union{Nothing, Symbol}}(:value=>:v_var, :type=>Real), :v_array=>Dict(:value=>nothing, :type=>Real), :p1=>Dict(:value=>nothing))
255-
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin))
260+
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2), :v=>Dict{Symbol, Any}(:value=>:v_var, :type=>Real), :v_array=>Dict{Symbol, Union{Nothing, UnionAll}}(:value=>nothing, :type=>AbstractArray{Real}), :p1=>Dict(:value=>nothing))
261+
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2))
256262
:independent_variable => t
257-
:constants => Dict{Symbol, Dict}(:c=>Dict(:value=>1))
263+
:constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant."))
258264
:extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB]
259265
:equations => Any["model_a.k ~ f(v)"]
260266
```

src/systems/abstractsystem.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,7 +1391,7 @@ function _named(name, call, runtime = false)
13911391
end
13921392
end
13931393

1394-
function _named_idxs(name::Symbol, idxs, call)
1394+
function _named_idxs(name::Symbol, idxs, call; extra_args = "")
13951395
if call.head !== :->
13961396
throw(ArgumentError("Not an anonymous function"))
13971397
end
@@ -1402,7 +1402,10 @@ function _named_idxs(name::Symbol, idxs, call)
14021402
ex = Base.Cartesian.poplinenum(ex)
14031403
ex = _named(:(Symbol($(Meta.quot(name)), :_, $sym)), ex, true)
14041404
ex = Base.Cartesian.poplinenum(ex)
1405-
:($name = $map($sym -> $ex, $idxs))
1405+
:($name = map($sym -> begin
1406+
$extra_args
1407+
$ex
1408+
end, $idxs))
14061409
end
14071410

14081411
function single_named_expr(expr)

src/systems/model_parsing.jl

Lines changed: 89 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ struct Model{F, S}
2424
end
2525
(m::Model)(args...; kw...) = m.f(args...; kw...)
2626

27+
Base.parentmodule(m::Model) = parentmodule(m.f)
28+
2729
for f in (:connector, :mtkmodel)
2830
isconnector = f == :connector ? true : false
2931
@eval begin
@@ -40,7 +42,7 @@ function _model_macro(mod, name, expr, isconnector)
4042
:kwargs => Dict{Symbol, Dict}(),
4143
:structural_parameters => Dict{Symbol, Dict}()
4244
)
43-
comps = Symbol[]
45+
comps = Union{Symbol, Expr}[]
4446
ext = Ref{Any}(nothing)
4547
eqs = Expr[]
4648
icon = Ref{Union{String, URI}}()
@@ -745,7 +747,7 @@ end
745747

746748
### Parsing Components:
747749

748-
function component_args!(a, b, expr, varexpr, kwargs)
750+
function component_args!(a, b, varexpr, kwargs; index_name = nothing)
749751
# Whenever `b` is a function call, skip the first arg aka the function name.
750752
# Whenever it is a kwargs list, include it.
751753
start = b.head == :call ? 2 : 1
@@ -754,73 +756,115 @@ function component_args!(a, b, expr, varexpr, kwargs)
754756
arg isa LineNumberNode && continue
755757
MLStyle.@match arg begin
756758
x::Symbol || Expr(:kw, x) => begin
757-
_v = _rename(a, x)
758-
b.args[i] = Expr(:kw, x, _v)
759-
push!(varexpr.args, :((@isdefined $x) && ($_v = $x)))
760-
push!(kwargs, Expr(:kw, _v, nothing))
761-
# dict[:kwargs][_v] = nothing
759+
varname, _varname = _rename(a, x)
760+
b.args[i] = Expr(:kw, x, _varname)
761+
push!(varexpr.args, :((if $varname !== nothing
762+
$_varname = $varname
763+
elseif @isdefined $x
764+
# Allow users to define a var in `structural_parameters` and set
765+
# that as positional arg of subcomponents; it is useful for cases
766+
# where it needs to be passed to multiple subcomponents.
767+
$_varname = $x
768+
end)))
769+
push!(kwargs, Expr(:kw, varname, nothing))
770+
# dict[:kwargs][varname] = nothing
762771
end
763772
Expr(:parameters, x...) => begin
764-
component_args!(a, arg, expr, varexpr, kwargs)
773+
component_args!(a, arg, varexpr, kwargs)
765774
end
766775
Expr(:kw, x, y) => begin
767-
_v = _rename(a, x)
768-
b.args[i] = Expr(:kw, x, _v)
769-
push!(varexpr.args, :($_v = $_v === nothing ? $y : $_v))
770-
push!(kwargs, Expr(:kw, _v, nothing))
771-
# dict[:kwargs][_v] = nothing
776+
varname, _varname = _rename(a, x)
777+
b.args[i] = Expr(:kw, x, _varname)
778+
if isnothing(index_name)
779+
push!(varexpr.args, :($_varname = $varname === nothing ? $y : $varname))
780+
else
781+
push!(varexpr.args,
782+
:($_varname = $varname === nothing ? $y : $varname[$index_name]))
783+
end
784+
push!(kwargs, Expr(:kw, varname, nothing))
785+
# dict[:kwargs][varname] = nothing
772786
end
773787
_ => error("Could not parse $arg of component $a")
774788
end
775789
end
776790
end
777791

778-
function _parse_components!(exprs, body, kwargs)
779-
expr = Expr(:block)
792+
model_name(name, range) = Symbol.(name, :_, collect(range))
793+
794+
function _parse_components!(body, kwargs)
795+
local expr
780796
varexpr = Expr(:block)
781-
# push!(exprs, varexpr)
782-
comps = Vector{Union{Symbol, Expr}}[]
797+
comps = Vector{Union{Union{Expr, Symbol}, Expr}}[]
783798
comp_names = []
784799

785-
for arg in body.args
786-
arg isa LineNumberNode && continue
787-
MLStyle.@match arg begin
788-
Expr(:block) => begin
789-
# TODO: Do we need this?
790-
error("Multiple `@components` block detected within a single block")
791-
end
792-
Expr(:(=), a, b) => begin
793-
arg = deepcopy(arg)
794-
b = deepcopy(arg.args[2])
800+
Base.remove_linenums!(body)
801+
arg = body.args[end]
795802

796-
component_args!(a, b, expr, varexpr, kwargs)
803+
MLStyle.@match arg begin
804+
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:(=), c, d)))) => begin
805+
array_varexpr = Expr(:block)
797806

798-
arg.args[2] = b
799-
push!(expr.args, arg)
800-
push!(comp_names, a)
801-
if (isa(b.args[1], Symbol) || Meta.isexpr(b.args[1], :.))
802-
push!(comps, [a, b.args[1]])
803-
end
807+
push!(comp_names, :($a...))
808+
push!(comps, [a, b.args[1], d])
809+
b = deepcopy(b)
810+
811+
component_args!(a, b, array_varexpr, kwargs; index_name = c)
812+
813+
expr = _named_idxs(a, d, :($c -> $b); extra_args = array_varexpr)
814+
end
815+
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:filter, e, Expr(:(=), c, d))))) => begin
816+
error("List comprehensions with conditional statements aren't supported.")
817+
end
818+
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:(=), c, d), e...))) => begin
819+
# Note that `e` is of the form `Tuple{Expr(:(=), c, d)}`
820+
error("More than one index isn't supported while building component array")
821+
end
822+
Expr(:block) => begin
823+
# TODO: Do we need this?
824+
error("Multiple `@components` block detected within a single block")
825+
end
826+
Expr(:(=), a, Expr(:for, Expr(:(=), c, d), b)) => begin
827+
Base.remove_linenums!(b)
828+
array_varexpr = Expr(:block)
829+
push!(array_varexpr.args, b.args[1:(end - 1)]...)
830+
push!(comp_names, :($a...))
831+
push!(comps, [a, b.args[end].args[1], d])
832+
b = deepcopy(b)
833+
834+
component_args!(a, b.args[end], array_varexpr, kwargs; index_name = c)
835+
836+
expr = _named_idxs(a, d, :($c -> $(b.args[end])); extra_args = array_varexpr)
837+
end
838+
Expr(:(=), a, b) => begin
839+
arg = deepcopy(arg)
840+
b = deepcopy(arg.args[2])
841+
842+
component_args!(a, b, varexpr, kwargs)
843+
844+
arg.args[2] = b
845+
expr = :(@named $arg)
846+
push!(comp_names, a)
847+
if (isa(b.args[1], Symbol) || Meta.isexpr(b.args[1], :.))
848+
push!(comps, [a, b.args[1]])
804849
end
805-
_ => error("Couldn't parse the component body: $arg")
806850
end
851+
_ => error("Couldn't parse the component body: $arg")
807852
end
853+
808854
return comp_names, comps, expr, varexpr
809855
end
810856

811857
function push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
812858
blk = Expr(:block)
813859
push!(blk.args, varexpr)
814-
push!(blk.args, :(@named begin
815-
$(expr_vec.args...)
816-
end))
860+
push!(blk.args, expr_vec)
817861
push!(blk.args, :($push!(systems, $(comp_names...))))
818862
push!(ifexpr.args, blk)
819863
end
820864

821865
function handle_if_x!(mod, exprs, ifexpr, x, kwargs, condition = nothing)
822866
push!(ifexpr.args, condition)
823-
comp_names, comps, expr_vec, varexpr = _parse_components!(ifexpr, x, kwargs)
867+
comp_names, comps, expr_vec, varexpr = _parse_components!(x, kwargs)
824868
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
825869
comps
826870
end
@@ -836,7 +880,7 @@ function handle_if_y!(exprs, ifexpr, y, kwargs)
836880
push!(ifexpr.args, elseifexpr)
837881
(comps...,)
838882
else
839-
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs, y, kwargs)
883+
comp_names, comps, expr_vec, varexpr = _parse_components!(y, kwargs)
840884
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
841885
comps
842886
end
@@ -861,25 +905,23 @@ function parse_components!(exprs, cs, dict, compbody, kwargs)
861905
Expr(:if, condition, x, y) => begin
862906
handle_conditional_components(condition, dict, exprs, kwargs, x, y)
863907
end
864-
Expr(:(=), a, b) => begin
865-
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs,
866-
:(begin
908+
# Either the arg is top level component declaration or an invalid cause - both are handled by `_parse_components`
909+
_ => begin
910+
comp_names, comps, expr_vec, varexpr = _parse_components!(:(begin
867911
$arg
868912
end),
869913
kwargs)
870914
push!(cs, comp_names...)
871915
push!(dict[:components], comps...)
872-
push!(exprs, varexpr, :(@named begin
873-
$(expr_vec.args...)
874-
end))
916+
push!(exprs, varexpr, expr_vec)
875917
end
876-
_ => error("Couldn't parse the component body $compbody")
877918
end
878919
end
879920
end
880921

881922
function _rename(compname, varname)
882923
compname = Symbol(compname, :__, varname)
924+
(compname, Symbol(:_, compname))
883925
end
884926

885927
# Handle top level branching

test/model_parsing.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,3 +650,72 @@ end
650650
@named m = MyModel()
651651
@variables x___(t)
652652
@test isequal(x___, _b[])
653+
654+
@testset "Component array" begin
655+
@mtkmodel SubComponent begin
656+
@parameters begin
657+
sc
658+
end
659+
end
660+
661+
@mtkmodel Component begin
662+
@structural_parameters begin
663+
N = 2
664+
end
665+
@components begin
666+
comprehension = [SubComponent(sc = i) for i in 1:N]
667+
written_out_for = for i in 1:N
668+
sc = i + 1
669+
SubComponent(; sc)
670+
end
671+
single_sub_component = SubComponent()
672+
end
673+
end
674+
675+
@named component = Component()
676+
component = complete(component)
677+
678+
@test nameof.(ModelingToolkit.get_systems(component)) == [
679+
:comprehension_1,
680+
:comprehension_2,
681+
:written_out_for_1,
682+
:written_out_for_2,
683+
:single_sub_component
684+
]
685+
686+
@test getdefault(component.comprehension_1.sc) == 1
687+
@test getdefault(component.comprehension_2.sc) == 2
688+
@test getdefault(component.written_out_for_1.sc) == 2
689+
@test getdefault(component.written_out_for_2.sc) == 3
690+
691+
@mtkmodel ConditionalComponent begin
692+
@structural_parameters begin
693+
N = 2
694+
end
695+
@components begin
696+
if N == 2
697+
if_comprehension = [SubComponent(sc = i) for i in 1:N]
698+
elseif N == 3
699+
elseif_comprehension = [SubComponent(sc = i) for i in 1:N]
700+
else
701+
else_comprehension = [SubComponent(sc = i) for i in 1:N]
702+
end
703+
end
704+
end
705+
706+
@named if_component = ConditionalComponent()
707+
@test nameof.(get_systems(if_component)) == [:if_comprehension_1, :if_comprehension_2]
708+
709+
@named elseif_component = ConditionalComponent(; N = 3)
710+
@test nameof.(get_systems(elseif_component)) ==
711+
[:elseif_comprehension_1, :elseif_comprehension_2, :elseif_comprehension_3]
712+
713+
@named else_component = ConditionalComponent(; N = 4)
714+
@test nameof.(get_systems(else_component)) ==
715+
[:else_comprehension_1, :else_comprehension_2,
716+
:else_comprehension_3, :else_comprehension_4]
717+
end
718+
719+
@testset "Parent module of Models" begin
720+
@test parentmodule(MyMockModule.Ground) == MyMockModule
721+
end

0 commit comments

Comments
 (0)