Skip to content

Commit ce73940

Browse files
committed
feat: extend all arguments of a base sys to sys
1 parent c1b1af1 commit ce73940

File tree

3 files changed

+84
-31
lines changed

3 files changed

+84
-31
lines changed

docs/src/basics/MTKLanguage.md

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ end
7373
v_array(t)[1:N, 1:M]
7474
v_for_defaults(t)
7575
end
76-
@extend ModelB(; p1)
76+
@extend ModelB(p1 = 1)
7777
@components begin
7878
model_a = ModelA(; k_array)
7979
model_array_a = [ModelA(; k = i) for i in 1:N]
@@ -149,14 +149,18 @@ julia> ModelingToolkit.getdefault(model_c1.v)
149149

150150
#### `@extend` begin block
151151

152-
- Partial systems can be extended in a higher system as `@extend PartialSystem(; kwargs)`.
153-
- Keyword arguments pf partial system in the `@extend` definition are added as the keyword arguments of the base system.
154-
- Note that in above example, `p1` is promoted as an argument of `ModelC`. Users can set the value of `p1`. However, as `p2` isn't listed in the model definition, its initial guess can't be specified while creating an instance of `ModelC`.
152+
Partial systems can be extended in a higher system in two ways:
155153

156-
```julia
157-
julia> @mtkbuild model_c2 = ModelC(; p1 = 2.0)
154+
- `@extend PartialSystem(var1 = value1)`
155+
156+
+ This is the recommended way of extending a base system.
157+
+ The default values for the arguments of the base system can be declared in the `@extend` statement.
158+
+ Note that all keyword arguments of the base system are added as the keyword arguments of the main system.
158159

159-
```
160+
- `@extend var_to_unpack1, var_to_unpack2 = partial_sys = PartialSystem(var1 = value1)`
161+
162+
+ In this method: explicitly list the variables that should be unpacked, provide a name for the partial system and declare the base system.
163+
+ Note that only the arguments listed out in the declaration of the base system (here: `var1`) are added as the keyword arguments of the higher system.
160164

161165
#### `@components` begin block
162166

@@ -325,11 +329,11 @@ For example, the structure of `ModelC` is:
325329
julia> ModelC.structure
326330
Dict{Symbol, Any} with 10 entries:
327331
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA], Union{Expr, Symbol}[:model_array_a, :ModelA, :(1:N)], Union{Expr, Symbol}[:model_array_b, :ModelA, :(1:N)]]
328-
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_for_defaults=>Dict(:type=>Real))
332+
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_array=>Dict(:value=>nothing, :type=>Real, :size=>(:N, :M)), :v_for_defaults=>Dict(:type=>Real))
329333
:icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png")
330-
:kwargs => Dict{Symbol, Dict}(:f => Dict(:value => :sin), :N => Dict(:value => 2), :M => Dict(:value => 3), :v => Dict{Symbol, Any}(:value => :v_var, :type => Real), :v_for_defaults => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real), :p1 => Dict(:value => nothing)),
331-
:structural_parameters => Dict{Symbol, Dict}(:f => Dict(:value => :sin), :N => Dict(:value => 2), :M => Dict(:value => 3))
332-
:independent_variable => t
334+
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :p2=>Dict(:value=>NoValue()), :N=>Dict(:value=>2), :M=>Dict(:value=>3), :v=>Dict{Symbol, Any}(:value=>:v_var, :type=>Real), :v_array=>Dict{Symbol, Any}(:value=>nothing, :type=>Real, :size=>(:N, :M)), :v_for_defaults=>Dict{Symbol, Union{Nothing, DataType}}(:value=>nothing, :type=>Real), :p1=>Dict(:value=>1))
335+
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2), :M=>Dict(:value=>3))
336+
:independent_variable => :t
333337
:constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant."))
334338
:extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB]
335339
:defaults => Dict{Symbol, Any}(:v_for_defaults=>2.0)

src/systems/model_parsing.jl

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ function parse_structural_parameters!(exprs, sps, dict, mod, body, kwargs)
717717
end
718718
end
719719

720-
function extend_args!(a, b, dict, expr, kwargs, varexpr, has_param = false)
720+
function extend_args!(a, b, dict, expr, kwargs, has_param = false)
721721
# Whenever `b` is a function call, skip the first arg aka the function name.
722722
# Whenever it is a kwargs list, include it.
723723
start = b.head == :call ? 2 : 1
@@ -738,18 +738,18 @@ function extend_args!(a, b, dict, expr, kwargs, varexpr, has_param = false)
738738
dict[:kwargs][x] = Dict(:value => nothing)
739739
end
740740
Expr(:kw, x) => begin
741+
b.args[i] = Expr(:kw, x, x)
741742
push!(kwargs, Expr(:kw, x, nothing))
742743
dict[:kwargs][x] = Dict(:value => nothing)
743744
end
744745
Expr(:kw, x, y) => begin
745746
b.args[i] = Expr(:kw, x, x)
746-
push!(varexpr.args, :($x = $x === nothing ? $y : $x))
747-
push!(kwargs, Expr(:kw, x, nothing))
748-
dict[:kwargs][x] = Dict(:value => nothing)
747+
push!(kwargs, Expr(:kw, x, y))
748+
dict[:kwargs][x] = Dict(:value => y)
749749
end
750750
Expr(:parameters, x...) => begin
751751
has_param = true
752-
extend_args!(a, arg, dict, expr, kwargs, varexpr, has_param)
752+
extend_args!(a, arg, dict, expr, kwargs, has_param)
753753
end
754754
_ => error("Could not parse $arg of component $a")
755755
end
@@ -758,17 +758,40 @@ end
758758

759759
const EMPTY_DICT = Dict()
760760
const EMPTY_VoVoSYMBOL = Vector{Symbol}[]
761+
const EMPTY_VoVoVoSYMBOL = Vector{Symbol}[[]]
761762

762-
function Base.names(model::Model)
763+
function _arguments(model::Model)
763764
vars = keys(get(model.structure, :variables, EMPTY_DICT))
764765
vars = union(vars, keys(get(model.structure, :parameters, EMPTY_DICT)))
765-
vars = union(vars,
766-
map(first, get(model.structure, :components, EMPTY_VoVoSYMBOL)))
766+
vars = union(vars, first(get(model.structure, :extend, EMPTY_VoVoVoSYMBOL)))
767767
collect(vars)
768768
end
769769

770-
function _parse_extend!(ext, a, b, dict, expr, kwargs, varexpr, vars)
771-
extend_args!(a, b, dict, expr, kwargs, varexpr)
770+
function Base.names(model::Model)
771+
collect(union(_arguments(model),
772+
map(first, get(model.structure, :components, EMPTY_VoVoSYMBOL))))
773+
end
774+
775+
function _parse_extend!(ext, a, b, dict, expr, kwargs, vars, additional_args)
776+
extend_args!(a, b, dict, expr, kwargs)
777+
778+
# `additional_args` doubles as a flag to check the mode of `@extend`. It is
779+
# `nothing` for explicit destructuring.
780+
# The following block modifies the arguments of both base and higher systems
781+
# for the implicit extend statements.
782+
if additional_args !== nothing
783+
b.args = [b.args[1]]
784+
allvars = [additional_args.args..., vars.args...]
785+
push!(b.args, Expr(:parameters))
786+
for var in allvars
787+
push!(b.args[end].args, var)
788+
if !haskey(dict[:kwargs], var)
789+
push!(dict[:kwargs], var => Dict(:value => NO_VALUE))
790+
push!(kwargs, Expr(:kw, var, NO_VALUE))
791+
end
792+
end
793+
end
794+
772795
ext[] = a
773796
push!(b.args, Expr(:kw, :name, Meta.quot(a)))
774797
push!(expr.args, :($a = $b))
@@ -780,8 +803,6 @@ end
780803

781804
function parse_extend!(exprs, ext, dict, mod, body, kwargs)
782805
expr = Expr(:block)
783-
varexpr = Expr(:block)
784-
push!(exprs, varexpr)
785806
push!(exprs, expr)
786807
body = deepcopy(body)
787808
MLStyle.@match body begin
@@ -792,7 +813,9 @@ function parse_extend!(exprs, ext, dict, mod, body, kwargs)
792813
error("`@extend` destructuring only takes an tuple as LHS. Got $body")
793814
end
794815
a, b = b.args
795-
_parse_extend!(ext, a, b, dict, expr, kwargs, varexpr, vars)
816+
# This doubles as a flag to identify the mode of `@extend`
817+
additional_args = nothing
818+
_parse_extend!(ext, a, b, dict, expr, kwargs, vars, additional_args)
796819
else
797820
error("When explicitly destructing in `@extend` please use the syntax: `@extend a, b = oneport = OnePort()`.")
798821
end
@@ -802,8 +825,11 @@ function parse_extend!(exprs, ext, dict, mod, body, kwargs)
802825
b = body
803826
if (model = getproperty(mod, b.args[1])) isa Model
804827
vars = Expr(:tuple)
805-
append!(vars.args, names(model))
806-
_parse_extend!(ext, a, b, dict, expr, kwargs, varexpr, vars)
828+
append!(vars.args, _arguments(model))
829+
additional_args = Expr(:tuple)
830+
append!(additional_args.args,
831+
keys(get(model.structure, :structural_parameters, EMPTY_DICT)))
832+
_parse_extend!(ext, a, b, dict, expr, kwargs, vars, additional_args)
807833
else
808834
error("Cannot infer the exact `Model` that `@extend $(body)` refers." *
809835
" Please specify the names that it brings into scope by:" *
@@ -1104,7 +1130,7 @@ function parse_icon!(body::String, dict, icon, mod)
11041130
icon_dir = get(ENV, "MTK_ICONS_DIR", joinpath(DEPOT_PATH[1], "mtk_icons"))
11051131
dict[:icon] = icon[] = if isfile(body)
11061132
URI("file:///" * abspath(body))
1107-
elseif (iconpath = joinpath(icon_dir, body); isfile(iconpath))
1133+
elseif (iconpath = abspath(joinpath(icon_dir, body)); isfile(iconpath))
11081134
URI("file:///" * abspath(iconpath))
11091135
elseif try
11101136
Base.isvalid(URI(body))
@@ -1115,6 +1141,7 @@ function parse_icon!(body::String, dict, icon, mod)
11151141
elseif (_body = lstrip(body); startswith(_body, r"<\?xml|<svg"))
11161142
String(_body) # With Julia-1.10 promoting `SubString{String}` to `String` can be dropped.
11171143
else
1144+
@info iconpath=joinpath(icon_dir, body) isfile(iconpath) body
11181145
error("\n$body is not a valid icon")
11191146
end
11201147
end

test/model_parsing.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ R_val = 20u"Ω"
153153
res__R = 100u"Ω"
154154
@mtkbuild rc = RC(; C_val, R_val, resistor.R = res__R)
155155
prob = ODEProblem(rc, [], (0, 1e9))
156-
sol = solve(prob, Rodas5P())
156+
sol = solve(prob)
157157
defs = ModelingToolkit.defaults(rc)
158158
@test sol[rc.capacitor.v, end] defs[rc.constant.k]
159159
resistor = getproperty(rc, :resistor; namespace = false)
@@ -459,9 +459,9 @@ end
459459
@test A.structure[:parameters] == Dict(:p => Dict(:type => Real))
460460
@test A.structure[:extend] == [[:e], :extended_e, :E]
461461
@test A.structure[:equations] == ["e ~ 0"]
462-
@test A.structure[:kwargs] ==
463-
Dict{Symbol, Dict}(:p => Dict(:value => nothing, :type => Real),
464-
:v => Dict(:value => nothing, :type => Real))
462+
@test A.structure[:kwargs] == Dict{Symbol, Dict}(
463+
:p => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real),
464+
:v => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real))
465465
@test A.structure[:components] == [[:cc, :C]]
466466
end
467467

@@ -910,3 +910,25 @@ end
910910
end),
911911
false)
912912
end
913+
914+
@mtkmodel BaseSys begin
915+
@parameters begin
916+
p1
917+
p2
918+
end
919+
@variables begin
920+
v1(t)
921+
end
922+
end
923+
924+
@testset "Arguments of base system" begin
925+
@mtkmodel MainSys begin
926+
@extend BaseSys(p1 = 1)
927+
end
928+
929+
@test names(MainSys) == [:p2, :p1, :v1]
930+
@named main_sys = MainSys(p1 = 11, p2 = 12, v1 = 13)
931+
@test getdefault(main_sys.p1) == 11
932+
@test getdefault(main_sys.p2) == 12
933+
@test getdefault(main_sys.v1) == 13
934+
end

0 commit comments

Comments
 (0)