Skip to content

Commit c6af68d

Browse files
authored
Parse the args in extended components (#2202)
1 parent bf3551e commit c6af68d

File tree

4 files changed

+59
-34
lines changed

4 files changed

+59
-34
lines changed

src/systems/abstractsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ function varname_fix!(expr::Expr)
941941
for arg in expr.args
942942
MLStyle.@match arg begin
943943
::Symbol => continue
944-
Expr(:kw, a) => varname_sanitization!(arg)
944+
Expr(:kw, a...) || Expr(:kw, a) => varname_sanitization!(arg)
945945
Expr(:parameters, a...) => begin
946946
for _arg in arg.args
947947
varname_sanitization!(_arg)

src/systems/model_parsing.jl

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ end
1515
@inline is_kwarg(::Symbol) = false
1616
@inline is_kwarg(e::Expr) = (e.head == :parameters)
1717

18-
function connector_macro(mod, name, body; arglist = Set([]), kwargs = Set([]))
18+
function connector_macro(mod, name, body)
1919
if !Meta.isexpr(body, :block)
2020
err = """
2121
connector body must be a block! It should be in the form of
@@ -29,6 +29,7 @@ function connector_macro(mod, name, body; arglist = Set([]), kwargs = Set([]))
2929
error(err)
3030
end
3131
vs = []
32+
kwargs = []
3233
icon = Ref{Union{String, URI}}()
3334
dict = Dict{Symbol, Any}()
3435
dict[:kwargs] = Dict{Symbol, Any}()
@@ -48,7 +49,7 @@ function connector_macro(mod, name, body; arglist = Set([]), kwargs = Set([]))
4849
gui_metadata = isassigned(icon) ? GUIMetadata(GlobalRef(mod, name), icon[]) :
4950
nothing
5051
quote
51-
$name = $Model(($(arglist...); name, $(kwargs...)) -> begin
52+
$name = $Model((; name, $(kwargs...)) -> begin
5253
$expr
5354
var"#___sys___" = $ODESystem($(Equation[]), $iv, [$(vs...)], $([]);
5455
name, gui_metadata = $gui_metadata)
@@ -173,7 +174,7 @@ function get_var(mod::Module, b)
173174
b isa Symbol ? getproperty(mod, b) : b
174175
end
175176

176-
function mtkmodel_macro(mod, name, expr; arglist = Set([]), kwargs = Set([]))
177+
function mtkmodel_macro(mod, name, expr)
177178
exprs = Expr(:block)
178179
dict = Dict{Symbol, Any}()
179180
dict[:kwargs] = Dict{Symbol, Any}()
@@ -183,6 +184,7 @@ function mtkmodel_macro(mod, name, expr; arglist = Set([]), kwargs = Set([]))
183184
icon = Ref{Union{String, URI}}()
184185
vs = []
185186
ps = []
187+
kwargs = []
186188

187189
for arg in expr.args
188190
arg isa LineNumberNode && continue
@@ -211,7 +213,7 @@ function mtkmodel_macro(mod, name, expr; arglist = Set([]), kwargs = Set([]))
211213
push!(exprs.args, :($extend($sys, $(ext[]))))
212214
end
213215

214-
:($name = $Model(($(arglist...); name, $(kwargs...)) -> $exprs, $dict))
216+
:($name = $Model((; name, $(kwargs...)) -> $exprs, $dict))
215217
end
216218

217219
function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, dict,
@@ -221,7 +223,7 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, dict,
221223
if mname == Symbol("@components")
222224
parse_components!(exprs, comps, dict, body, kwargs)
223225
elseif mname == Symbol("@extend")
224-
parse_extend!(exprs, ext, dict, body)
226+
parse_extend!(exprs, ext, dict, body, kwargs)
225227
elseif mname == Symbol("@variables")
226228
parse_variables!(exprs, vs, dict, mod, body, :variables, kwargs)
227229
elseif mname == Symbol("@parameters")
@@ -272,35 +274,25 @@ function component_args!(a, b, expr, kwargs)
272274
arg = b.args[i]
273275
arg isa LineNumberNode && continue
274276
MLStyle.@match arg begin
275-
::Symbol => begin
276-
_v = _rename(a, arg)
277-
push!(kwargs, _v)
278-
b.args[i] = Expr(:kw, arg, _v)
279-
end
280-
Expr(:parameters, x...) => begin
281-
component_args!(a, arg, expr, kwargs)
282-
end
283-
Expr(:kw, x) => begin
277+
x::Symbol || Expr(:kw, x) => begin
284278
_v = _rename(a, x)
285279
b.args[i] = Expr(:kw, x, _v)
286-
push!(kwargs, _v)
280+
push!(kwargs, Expr(:kw, _v, nothing))
287281
end
288-
Expr(:kw, x, y::Number) => begin
289-
_v = _rename(a, x)
290-
b.args[i] = Expr(:kw, x, _v)
291-
push!(kwargs, Expr(:kw, _v, y))
282+
Expr(:parameters, x...) => begin
283+
component_args!(a, arg, expr, kwargs)
292284
end
293285
Expr(:kw, x, y) => begin
294286
_v = _rename(a, x)
295-
push!(expr.args, :($y = $_v))
287+
b.args[i] = Expr(:kw, x, _v)
296288
push!(kwargs, Expr(:kw, _v, y))
297289
end
298290
_ => error("Could not parse $arg of component $a")
299291
end
300292
end
301293
end
302294

303-
function parse_extend!(exprs, ext, dict, body)
295+
function parse_extend!(exprs, ext, dict, body, kwargs)
304296
expr = Expr(:block)
305297
push!(exprs, expr)
306298
body = deepcopy(body)
@@ -313,6 +305,7 @@ function parse_extend!(exprs, ext, dict, body)
313305
error("`@extend` destructuring only takes an tuple as LHS. Got $body")
314306
end
315307
a, b = b.args
308+
component_args!(a, b, expr, kwargs)
316309
vars, a, b
317310
end
318311
ext[] = a

test/jumpsystem.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
using ModelingToolkit, DiffEqBase, JumpProcesses, Test, LinearAlgebra
1+
using ModelingToolkit, DiffEqBase, JumpProcesses, Test, LinearAlgebra, StableRNGs
22
MT = ModelingToolkit
33

4+
rng = StableRNG(12345)
5+
46
# basic MT SIR model with tweaks
57
@parameters β γ t
68
@constants h = 1
@@ -63,7 +65,7 @@ tspan = (0.0, 250.0);
6365
u₀map = [S => 999, I => 1, R => 0]
6466
parammap ==> 0.1 / 1000, γ => 0.01]
6567
dprob = DiscreteProblem(js2, u₀map, tspan, parammap)
66-
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false))
68+
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
6769
Nsims = 30000
6870
function getmean(jprob, Nsims)
6971
m = 0.0
@@ -79,13 +81,13 @@ m = getmean(jprob, Nsims)
7981
obs = [S2 ~ 2 * S]
8082
@named js2b = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ], observed = obs)
8183
dprob = DiscreteProblem(js2b, u₀map, tspan, parammap)
82-
jprob = JumpProblem(js2b, dprob, Direct(), save_positions = (false, false))
84+
jprob = JumpProblem(js2b, dprob, Direct(), save_positions = (false, false), rng = rng)
8385
sol = solve(jprob, SSAStepper(), saveat = tspan[2] / 10)
8486
@test all(2 .* sol[S] .== sol[S2])
8587

8688
# test save_positions is working
8789

88-
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false))
90+
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
8991
sol = solve(jprob, SSAStepper(), saveat = 1.0)
9092
@test all((sol.t) .== collect(0.0:tspan[2]))
9193

@@ -120,7 +122,7 @@ function a2!(integrator)
120122
end
121123
j2 = ConstantRateJump(r2, a2!)
122124
jset = JumpSet((), (j1, j2), nothing, nothing)
123-
jprob = JumpProblem(prob, Direct(), jset, save_positions = (false, false))
125+
jprob = JumpProblem(prob, Direct(), jset, save_positions = (false, false), rng = rng)
124126
m2 = getmean(jprob, Nsims)
125127

126128
# test JumpSystem solution agrees with direct version
@@ -131,16 +133,16 @@ maj1 = MassActionJump(2 * β / 2, [S => 1, I => 1], [S => -1, I => 1])
131133
maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
132134
@named js3 = JumpSystem([maj1, maj2], t, [S, I, R], [β, γ])
133135
dprob = DiscreteProblem(js3, u₀map, tspan, parammap)
134-
jprob = JumpProblem(js3, dprob, Direct())
136+
jprob = JumpProblem(js3, dprob, Direct(), rng = rng)
135137
m3 = getmean(jprob, Nsims)
136138
@test abs(m - m3) / m < 0.01
137139

138140
# maj jump test with various dep graphs
139141
@named js3b = JumpSystem([maj1, maj2], t, [S, I, R], [β, γ])
140-
jprobb = JumpProblem(js3b, dprob, NRM())
142+
jprobb = JumpProblem(js3b, dprob, NRM(), rng = rng)
141143
m4 = getmean(jprobb, Nsims)
142144
@test abs(m - m4) / m < 0.01
143-
jprobc = JumpProblem(js3b, dprob, RSSA())
145+
jprobc = JumpProblem(js3b, dprob, RSSA(), rng = rng)
144146
m4 = getmean(jprobc, Nsims)
145147
@test abs(m - m4) / m < 0.01
146148

@@ -149,7 +151,7 @@ maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
149151
maj2 = MassActionJump(γ, [S => 1], [S => -1])
150152
@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ])
151153
dprob = DiscreteProblem(js4, [S => 999], (0, 1000.0), [β => 100.0, γ => 0.01])
152-
jprob = JumpProblem(js4, dprob, Direct())
154+
jprob = JumpProblem(js4, dprob, Direct(), rng = rng)
153155
m4 = getmean(jprob, Nsims)
154156
@test abs(m4 - 2.0 / 0.01) * 0.01 / 2.0 < 0.01
155157

@@ -158,7 +160,7 @@ maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
158160
maj2 = MassActionJump(γ, [S => 2], [S => -1])
159161
@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ])
160162
dprob = DiscreteProblem(js4, [S => 999], (0, 1000.0), [β => 100.0, γ => 0.01])
161-
jprob = JumpProblem(js4, dprob, Direct())
163+
jprob = JumpProblem(js4, dprob, Direct(), rng = rng)
162164
sol = solve(jprob, SSAStepper());
163165

164166
# issue #819
@@ -179,7 +181,7 @@ p = [k1 => 2.0, k2 => 0.0, k3 => 0.5]
179181
u₀ = [A => 100, B => 0]
180182
tspan = (0.0, 2000.0)
181183
dprob = DiscreteProblem(js5, u₀, tspan, p)
182-
jprob = JumpProblem(js5, dprob, Direct(), save_positions = (false, false))
184+
jprob = JumpProblem(js5, dprob, Direct(), save_positions = (false, false), rng = rng)
183185
@test all(jprob.massaction_jump.scaled_rates .== [1.0, 0.0])
184186

185187
pcondit(u, t, integrator) = t == 1000.0

test/model_parsing.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,13 @@ l15 0" stroke="black" stroke-width="1" stroke-linejoin="bevel" fill="none"></pat
9191
end
9292

9393
@mtkmodel Capacitor begin
94-
@extend v, i = oneport = OnePort()
9594
@parameters begin
9695
C
9796
end
97+
@variables begin
98+
v = 0.0
99+
end
100+
@extend v, i = oneport = OnePort(; v = v)
98101
@icon "https://upload.wikimedia.org/wikipedia/commons/7/78/Capacitor_symbol.svg"
99102
@equations begin
100103
D(v) ~ i / C
@@ -182,3 +185,30 @@ model = complete(model)
182185
@test getdefault(model.i) == 4
183186
@test isequal(getdefault(model.j), model.jval)
184187
@test isequal(getdefault(model.k), model.kval)
188+
189+
@mtkmodel A begin
190+
@parameters begin
191+
p
192+
end
193+
@components begin
194+
b = B(i = p, j = 1 / p, k = 1)
195+
end
196+
end
197+
198+
@mtkmodel B begin
199+
@parameters begin
200+
i
201+
j
202+
k
203+
end
204+
end
205+
206+
@named a = A(p = 10)
207+
getdefault(a.b.i) == 10
208+
getdefault(a.b.j) == 0.1
209+
getdefault(a.b.k) == 1
210+
211+
@named a = A(p = 10, b.i = 20, b.j = 30, b.k = 40)
212+
getdefault(a.b.i) == 20
213+
getdefault(a.b.j) == 30
214+
getdefault(a.b.k) == 40

0 commit comments

Comments
 (0)