Skip to content

Commit 314bb1d

Browse files
authored
Merge pull request #2306 from SciML/myb/extend
Support implicit name unpack in `at extend`
2 parents 379161f + 73f9e0a commit 314bb1d

File tree

2 files changed

+144
-135
lines changed

2 files changed

+144
-135
lines changed

src/systems/model_parsing.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps,
237237
if mname == Symbol("@components")
238238
parse_components!(exprs, comps, dict, body, kwargs)
239239
elseif mname == Symbol("@extend")
240-
parse_extend!(exprs, ext, dict, body, kwargs)
240+
parse_extend!(exprs, ext, dict, mod, body, kwargs)
241241
elseif mname == Symbol("@variables")
242242
parse_variables!(exprs, vs, dict, mod, body, :variables, kwargs)
243243
elseif mname == Symbol("@parameters")
@@ -372,24 +372,35 @@ function extend_args!(a, b, dict, expr, kwargs, varexpr, has_param = false)
372372
end
373373
end
374374

375-
function parse_extend!(exprs, ext, dict, body, kwargs)
375+
function parse_extend!(exprs, ext, dict, mod, body, kwargs)
376376
expr = Expr(:block)
377377
varexpr = Expr(:block)
378378
push!(exprs, varexpr)
379379
push!(exprs, expr)
380380
body = deepcopy(body)
381381
MLStyle.@match body begin
382382
Expr(:(=), a, b) => begin
383-
vars = nothing
384383
if Meta.isexpr(b, :(=))
385384
vars = a
386385
if !Meta.isexpr(vars, :tuple)
387386
error("`@extend` destructuring only takes an tuple as LHS. Got $body")
388387
end
389388
a, b = b.args
390-
extend_args!(a, b, dict, expr, kwargs, varexpr)
391-
vars, a, b
389+
elseif Meta.isexpr(b, :call)
390+
if (model = getproperty(mod, b.args[1])) isa Model
391+
_vars = keys(get(model.structure, :variables, Dict()))
392+
_vars = union(_vars, keys(get(model.structure, :parameters, Dict())))
393+
_vars = union(_vars,
394+
map(first, get(model.structure, :components, Vector{Symbol}[])))
395+
vars = Expr(:tuple)
396+
append!(vars.args, collect(_vars))
397+
else
398+
error("Cannot infer the exact `Model` that `@extend $(body)` refers." *
399+
" Please specify the names that it brings into scope by:" *
400+
" `@extend a, b = oneport = OnePort()`.")
401+
end
392402
end
403+
extend_args!(a, b, dict, expr, kwargs, varexpr)
393404
ext[] = a
394405
push!(b.args, Expr(:kw, :name, Meta.quot(a)))
395406
push!(expr.args, :($a = $b))

test/model_parsing.jl

Lines changed: 128 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -7,77 +7,76 @@ using Unitful
77

88
ENV["MTK_ICONS_DIR"] = "$(@__DIR__)/icons"
99

10-
@testset "Comprehensive Test of Parsing Models (with an RC Circuit)" begin
11-
@connector RealInput begin
12-
u(t), [input = true, unit = u"V"]
10+
@connector RealInput begin
11+
u(t), [input = true, unit = u"V"]
12+
end
13+
@connector RealOutput begin
14+
u(t), [output = true, unit = u"V"]
15+
end
16+
@mtkmodel Constant begin
17+
@components begin
18+
output = RealOutput()
1319
end
14-
@connector RealOutput begin
15-
u(t), [output = true, unit = u"V"]
20+
@parameters begin
21+
k, [description = "Constant output value of block"]
1622
end
17-
@mtkmodel Constant begin
18-
@components begin
19-
output = RealOutput()
20-
end
21-
@parameters begin
22-
k, [description = "Constant output value of block"]
23-
end
24-
@equations begin
25-
output.u ~ k
26-
end
23+
@equations begin
24+
output.u ~ k
2725
end
26+
end
2827

29-
@variables t [unit = u"s"]
30-
D = Differential(t)
28+
@variables t [unit = u"s"]
29+
D = Differential(t)
3130

32-
@connector Pin begin
33-
v(t), [unit = u"V"] # Potential at the pin [V]
34-
i(t), [connect = Flow, unit = u"A"] # Current flowing into the pin [A]
35-
@icon "pin.png"
36-
end
31+
@connector Pin begin
32+
v(t), [unit = u"V"] # Potential at the pin [V]
33+
i(t), [connect = Flow, unit = u"A"] # Current flowing into the pin [A]
34+
@icon "pin.png"
35+
end
3736

38-
@named p = Pin(; v = π)
39-
@test getdefault(p.v) == π
40-
@test Pin.isconnector == true
37+
@named p = Pin(; v = π)
38+
@test getdefault(p.v) == π
39+
@test Pin.isconnector == true
4140

42-
@mtkmodel OnePort begin
43-
@components begin
44-
p = Pin()
45-
n = Pin()
46-
end
47-
@variables begin
48-
v(t), [unit = u"V"]
49-
i(t), [unit = u"A"]
50-
end
51-
@icon "oneport.png"
52-
@equations begin
53-
v ~ p.v - n.v
54-
0 ~ p.i + n.i
55-
i ~ p.i
56-
end
41+
@mtkmodel OnePort begin
42+
@components begin
43+
p = Pin()
44+
n = Pin()
45+
end
46+
@variables begin
47+
v(t), [unit = u"V"]
48+
i(t), [unit = u"A"]
5749
end
50+
@icon "oneport.png"
51+
@equations begin
52+
v ~ p.v - n.v
53+
0 ~ p.i + n.i
54+
i ~ p.i
55+
end
56+
end
5857

59-
@test OnePort.isconnector == false
58+
@test OnePort.isconnector == false
6059

61-
@mtkmodel Ground begin
62-
@components begin
63-
g = Pin()
64-
end
65-
@icon begin
66-
read(abspath(ENV["MTK_ICONS_DIR"], "ground.svg"), String)
67-
end
68-
@equations begin
69-
g.v ~ 0
70-
end
60+
@mtkmodel Ground begin
61+
@components begin
62+
g = Pin()
63+
end
64+
@icon begin
65+
read(abspath(ENV["MTK_ICONS_DIR"], "ground.svg"), String)
7166
end
67+
@equations begin
68+
g.v ~ 0
69+
end
70+
end
7271

73-
resistor_log = "$(@__DIR__)/logo/resistor.svg"
74-
@mtkmodel Resistor begin
75-
@extend v, i = oneport = OnePort()
76-
@parameters begin
77-
R, [unit = u""]
78-
end
79-
@icon begin
80-
"""<?xml version="1.0" encoding="UTF-8"?>
72+
resistor_log = "$(@__DIR__)/logo/resistor.svg"
73+
@mtkmodel Resistor begin
74+
@extend v, i = oneport = OnePort()
75+
@parameters begin
76+
R, [unit = u""]
77+
end
78+
@icon begin
79+
"""<?xml version="1.0" encoding="UTF-8"?>
8180
<svg xmlns="http://www.w3.org/2000/svg" width="80" height="30">
8281
<path d="M10 15
8382
l15 0
@@ -91,88 +90,87 @@ l2.5 -5
9190
l15 0" stroke="black" stroke-width="1" stroke-linejoin="bevel" fill="none"></path>
9291
</svg>
9392
"""
94-
end
95-
@equations begin
96-
v ~ i * R
97-
end
9893
end
94+
@equations begin
95+
v ~ i * R
96+
end
97+
end
9998

100-
@mtkmodel Capacitor begin
101-
@parameters begin
102-
C, [unit = u"F"]
103-
end
104-
@extend v, i = oneport = OnePort(; v = 0.0)
105-
@icon "https://upload.wikimedia.org/wikipedia/commons/7/78/Capacitor_symbol.svg"
106-
@equations begin
107-
D(v) ~ i / C
108-
end
99+
@mtkmodel Capacitor begin
100+
@parameters begin
101+
C, [unit = u"F"]
102+
end
103+
@extend oneport = OnePort(; v = 0.0)
104+
@icon "https://upload.wikimedia.org/wikipedia/commons/7/78/Capacitor_symbol.svg"
105+
@equations begin
106+
D(v) ~ i / C
109107
end
108+
end
110109

111-
@named capacitor = Capacitor(C = 10, v = 10.0)
112-
@test getdefault(capacitor.v) == 10.0
110+
@named capacitor = Capacitor(C = 10, v = 10.0)
111+
@test getdefault(capacitor.v) == 10.0
113112

114-
@mtkmodel Voltage begin
115-
@extend v, i = oneport = OnePort()
116-
@components begin
117-
V = RealInput()
118-
end
119-
@equations begin
120-
v ~ V.u
121-
end
113+
@mtkmodel Voltage begin
114+
@extend v, i = oneport = OnePort()
115+
@components begin
116+
V = RealInput()
122117
end
118+
@equations begin
119+
v ~ V.u
120+
end
121+
end
123122

124-
@mtkmodel RC begin
125-
@structural_parameters begin
126-
R_val = 10
127-
C_val = 10
128-
k_val = 10
129-
end
130-
@components begin
131-
resistor = Resistor(; R = R_val)
132-
capacitor = Capacitor(; C = C_val)
133-
source = Voltage()
134-
constant = Constant(; k = k_val)
135-
ground = Ground()
136-
end
137-
138-
@equations begin
139-
connect(constant.output, source.V)
140-
connect(source.p, resistor.p)
141-
connect(resistor.n, capacitor.p)
142-
connect(capacitor.n, source.n, ground.g)
143-
end
123+
@mtkmodel RC begin
124+
@structural_parameters begin
125+
R_val = 10
126+
C_val = 10
127+
k_val = 10
128+
end
129+
@components begin
130+
resistor = Resistor(; R = R_val)
131+
capacitor = Capacitor(; C = C_val)
132+
source = Voltage()
133+
constant = Constant(; k = k_val)
134+
ground = Ground()
144135
end
145136

146-
C_val = 20
147-
R_val = 20
148-
res__R = 100
149-
@mtkbuild rc = RC(; C_val, R_val, resistor.R = res__R)
150-
resistor = getproperty(rc, :resistor; namespace = false)
151-
@test getname(rc.resistor) === getname(resistor)
152-
@test getname(rc.resistor.R) === getname(resistor.R)
153-
@test getname(rc.resistor.v) === getname(resistor.v)
154-
# Test that `resistor.R` overrides `R_val` in the argument.
155-
@test getdefault(rc.resistor.R) == res__R != R_val
156-
# Test that `C_val` passed via argument is set as default of C.
157-
@test getdefault(rc.capacitor.C) == C_val
158-
# Test that `k`'s default value is unchanged.
159-
@test getdefault(rc.constant.k) == RC.structure[:kwargs][:k_val]
160-
@test getdefault(rc.capacitor.v) == 0.0
161-
162-
@test get_gui_metadata(rc.resistor).layout == Resistor.structure[:icon] ==
163-
read(joinpath(ENV["MTK_ICONS_DIR"], "resistor.svg"), String)
164-
@test get_gui_metadata(rc.ground).layout ==
165-
read(abspath(ENV["MTK_ICONS_DIR"], "ground.svg"), String)
166-
@test get_gui_metadata(rc.capacitor).layout ==
167-
URI("https://upload.wikimedia.org/wikipedia/commons/7/78/Capacitor_symbol.svg")
168-
@test OnePort.structure[:icon] ==
169-
URI("file:///" * abspath(ENV["MTK_ICONS_DIR"], "oneport.png"))
170-
@test ModelingToolkit.get_gui_metadata(rc.resistor.p).layout == Pin.structure[:icon] ==
171-
URI("file:///" * abspath(ENV["MTK_ICONS_DIR"], "pin.png"))
172-
173-
@test length(equations(rc)) == 1
137+
@equations begin
138+
connect(constant.output, source.V)
139+
connect(source.p, resistor.p)
140+
connect(resistor.n, capacitor.p)
141+
connect(capacitor.n, source.n, ground.g)
142+
end
174143
end
175144

145+
C_val = 20
146+
R_val = 20
147+
res__R = 100
148+
@mtkbuild rc = RC(; C_val, R_val, resistor.R = res__R)
149+
resistor = getproperty(rc, :resistor; namespace = false)
150+
@test getname(rc.resistor) === getname(resistor)
151+
@test getname(rc.resistor.R) === getname(resistor.R)
152+
@test getname(rc.resistor.v) === getname(resistor.v)
153+
# Test that `resistor.R` overrides `R_val` in the argument.
154+
@test getdefault(rc.resistor.R) == res__R != R_val
155+
# Test that `C_val` passed via argument is set as default of C.
156+
@test getdefault(rc.capacitor.C) == C_val
157+
# Test that `k`'s default value is unchanged.
158+
@test getdefault(rc.constant.k) == RC.structure[:kwargs][:k_val]
159+
@test getdefault(rc.capacitor.v) == 0.0
160+
161+
@test get_gui_metadata(rc.resistor).layout == Resistor.structure[:icon] ==
162+
read(joinpath(ENV["MTK_ICONS_DIR"], "resistor.svg"), String)
163+
@test get_gui_metadata(rc.ground).layout ==
164+
read(abspath(ENV["MTK_ICONS_DIR"], "ground.svg"), String)
165+
@test get_gui_metadata(rc.capacitor).layout ==
166+
URI("https://upload.wikimedia.org/wikipedia/commons/7/78/Capacitor_symbol.svg")
167+
@test OnePort.structure[:icon] ==
168+
URI("file:///" * abspath(ENV["MTK_ICONS_DIR"], "oneport.png"))
169+
@test ModelingToolkit.get_gui_metadata(rc.resistor.p).layout == Pin.structure[:icon] ==
170+
URI("file:///" * abspath(ENV["MTK_ICONS_DIR"], "pin.png"))
171+
172+
@test length(equations(rc)) == 1
173+
176174
@testset "Parameters and Structural parameters in various modes" begin
177175
@mtkmodel MockModel begin
178176
@parameters begin

0 commit comments

Comments
 (0)