Skip to content

Commit 551cf0d

Browse files
committed
feat: support multiple extend statements in MTKModel
1 parent f20fb5f commit 551cf0d

File tree

4 files changed

+50
-13
lines changed

4 files changed

+50
-13
lines changed

docs/src/basics/MTKLanguage.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ julia> ModelingToolkit.getdefault(model_c1.v)
147147
2.0
148148
```
149149

150-
#### `@extend` begin block
150+
#### `@extend` statement
151151

152-
Partial systems can be extended in a higher system in two ways:
152+
One or more partial systems can be extended in a higher system with `@extend` statements. This can be done in two ways:
153153

154154
- `@extend PartialSystem(var1 = value1)`
155155

@@ -313,7 +313,8 @@ end
313313
- `:components`: The list of sub-components in the form of [[name, sub_component_name],...].
314314
- `:constants`: Dictionary of constants mapped to its metadata.
315315
- `:defaults`: Dictionary of variables and default values specified in the `@defaults`.
316-
- `:extend`: The list of extended unknowns, name given to the base system, and name of the base system.
316+
- `:extend`: The list of extended unknowns, parameters and components, name given to the base system, and name of the base system.
317+
When multiple extend statements are present, latter two are returned as lists.
317318
- `:structural_parameters`: Dictionary of structural parameters mapped to their metadata.
318319
- `:parameters`: Dictionary of symbolic parameters mapped to their metadata. For
319320
parameter arrays, length is added to the metadata as `:size`.

src/systems/abstractsystem.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,7 @@ Mark a system as completed. A completed system is a system which is done being
918918
defined/modified and is ready for structural analysis or other transformations.
919919
This allows for analyses and optimizations to be performed which require knowing
920920
the global structure of the system.
921-
921+
922922
One property to note is that if a system is complete, the system will no longer
923923
namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
924924
"""
@@ -1933,7 +1933,7 @@ function Base.show(
19331933
end
19341934
end
19351935
limited = nrows < nsubs
1936-
limited && print(io, "\n") # too many to print
1936+
limited && print(io, "\n") # too many to print
19371937

19381938
# Print equations
19391939
eqs = equations(sys)
@@ -3043,10 +3043,19 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
30433043
return T(args...; kwargs...)
30443044
end
30453045

3046+
function extend(sys, basesys::Vector{T}) where {T <: AbstractSystem}
3047+
foldl(extend, basesys, init = sys)
3048+
end
3049+
30463050
function Base.:(&)(sys::AbstractSystem, basesys::AbstractSystem; kwargs...)
30473051
extend(sys, basesys; kwargs...)
30483052
end
30493053

3054+
function Base.:(&)(
3055+
sys::AbstractSystem, basesys::Vector{T}; kwargs...) where {T <: AbstractSystem}
3056+
extend(sys, basesys; kwargs...)
3057+
end
3058+
30503059
"""
30513060
$(SIGNATURES)
30523061

src/systems/model_parsing.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function _model_macro(mod, name, expr, isconnector)
5050
:structural_parameters => Dict{Symbol, Dict}()
5151
)
5252
comps = Union{Symbol, Expr}[]
53-
ext = Ref{Any}(nothing)
53+
ext = []
5454
eqs = Expr[]
5555
icon = Ref{Union{String, URI}}()
5656
ps, sps, vs, = [], [], []
@@ -115,10 +115,10 @@ function _model_macro(mod, name, expr, isconnector)
115115
sys = :($ODESystem($(flatten_equations)(equations), $iv, variables, parameters;
116116
name, systems, gui_metadata = $gui_metadata, defaults))
117117

118-
if ext[] === nothing
118+
if length(ext) == 0
119119
push!(exprs.args, :(var"#___sys___" = $sys))
120120
else
121-
push!(exprs.args, :(var"#___sys___" = $extend($sys, $(ext[]))))
121+
push!(exprs.args, :(var"#___sys___" = $extend($sys, [$(ext...)])))
122122
end
123123

124124
isconnector && push!(exprs.args,
@@ -240,7 +240,7 @@ function unit_handled_variable_value(meta, varname)
240240
end
241241

242242
# This function parses various variable/parameter definitions.
243-
#
243+
#
244244
# The comments indicate the syntax matched by a block; either when parsed directly
245245
# when it is called recursively for parsing a part of an expression.
246246
# These variable definitions are part of test suite in `test/model_parsing.jl`
@@ -286,7 +286,7 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
286286
# `(l2(t)[1:N, 1:M] = 2), [description = "l is more than 1D, with arbitrary length"]`
287287
# `(l3(t)[1:3] = 3), [description = "l2 is 1D"]`
288288
# `(l4(t)[1:N] = 4), [description = "l2 is 1D, with arbitrary length"]`
289-
#
289+
#
290290
# Condition 2 parses:
291291
# `(l5(t)[1:3]::Int = 5), [description = "l3 is 1D and has a type"]`
292292
# `(l6(t)[1:N]::Int = 6), [description = "l3 is 1D and has a type, with arbitrary length"]`
@@ -373,7 +373,7 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, where_types;
373373
# Condition 1 is recursively called by:
374374
# `par5[1:3]::BigFloat`
375375
# `par6(t)[1:3]::BigFloat`
376-
#
376+
#
377377
# Condition 2 parses:
378378
# `b2(t)[1:2]`
379379
# `a2[1:2]`
@@ -791,11 +791,17 @@ function _parse_extend!(ext, a, b, dict, expr, kwargs, vars, implicit_arglist)
791791
end
792792
end
793793

794-
ext[] = a
794+
push!(ext, a)
795795
push!(b.args, Expr(:kw, :name, Meta.quot(a)))
796796
push!(expr.args, :($a = $b))
797797

798-
dict[:extend] = [Symbol.(vars.args), a, b.args[1]]
798+
if !haskey(dict, :extend)
799+
dict[:extend] = [Symbol.(vars.args), a, b.args[1]]
800+
else
801+
push!(dict[:extend][1], Symbol.(vars.args)...)
802+
dict[:extend][2] = vcat(dict[:extend][2], a)
803+
dict[:extend][3] = vcat(dict[:extend][3], b.args[1])
804+
end
799805

800806
push!(expr.args, :(@unpack $vars = $a))
801807
end

test/model_parsing.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,15 @@ end
945945
end
946946
end
947947

948+
@mtkmodel MidModelB begin
949+
@parameters begin
950+
b
951+
end
952+
@components begin
953+
inmodel_b = InnerModel()
954+
end
955+
end
956+
948957
@mtkmodel OuterModel begin
949958
@extend MidModel()
950959
@equations begin
@@ -958,3 +967,15 @@ end
958967
@named out = OuterModel()
959968
@test OuterModel.structure[:extend][1] == [:inmodel]
960969
end
970+
971+
@mtkmodel MultipleExtend begin
972+
@extend MidModel()
973+
@extend MidModelB()
974+
end
975+
976+
@testset "Multiple extend statements" begin
977+
@named multiple_extend = MultipleExtend()
978+
@test collect(nameof.(multiple_extend.systems)) == [:inmodel_b, :inmodel]
979+
@test MultipleExtend.structure[:extend][1] == [:inmodel, :b, :inmodel_b]
980+
@test tosymbol.(parameters(multiple_extend)) == [:b, :inmodel_b₊p, :inmodel₊p]
981+
end

0 commit comments

Comments
 (0)