Skip to content

Commit a864541

Browse files
Merge pull request #2196 from ven-k/vkb/kw
Pass arguments (including hierarchal ones) to `Model`s
2 parents 50cf8f9 + fb7d334 commit a864541

File tree

4 files changed

+227
-51
lines changed

4 files changed

+227
-51
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ import SymbolicUtils.Code: toexpr
4949
import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint
5050
import JuliaFormatter
5151

52+
using MLStyle
53+
5254
using Reexport
5355
using Symbolics: degree
5456
@reexport using Symbolics

src/systems/abstractsystem.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,32 @@ function split_assign(expr)
935935
name, call = expr.args
936936
end
937937

938+
varname_fix!(s) = return
939+
940+
function varname_fix!(expr::Expr)
941+
for arg in expr.args
942+
MLStyle.@match arg begin
943+
::Symbol => continue
944+
Expr(:kw, a) => varname_sanitization!(arg)
945+
Expr(:parameters, a...) => begin
946+
for _arg in arg.args
947+
varname_sanitization!(_arg)
948+
end
949+
end
950+
_ => @debug "skipping variable sanitization of $arg"
951+
end
952+
end
953+
end
954+
955+
varname_sanitization!(a) = return
956+
957+
function varname_sanitization!(expr::Expr)
958+
var_splits = split(string(expr.args[1]), ".")
959+
if length(var_splits) > 1
960+
expr.args[1] = Symbol(join(var_splits, "__"))
961+
end
962+
end
963+
938964
function _named(name, call, runtime = false)
939965
has_kw = false
940966
call isa Expr || throw(Meta.ParseError("The rhs must be an Expr. Got $call."))
@@ -948,6 +974,8 @@ function _named(name, call, runtime = false)
948974
end
949975
end
950976

977+
varname_fix!(call)
978+
951979
if !has_kw
952980
param = Expr(:parameters)
953981
if length(call.args) == 1

src/systems/model_parsing.jl

Lines changed: 144 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
1-
macro connector(name::Symbol, body)
2-
esc(connector_macro(__module__, name, body))
3-
end
4-
51
struct Model{F, S}
62
f::F
73
structure::S
84
end
95
(m::Model)(args...; kw...) = m.f(args...; kw...)
106

11-
using MLStyle
7+
for f in (:connector, :model)
8+
@eval begin
9+
macro $f(name::Symbol, body)
10+
esc($(Symbol(f, :_macro))(__module__, name, body))
11+
end
12+
end
13+
end
14+
15+
@inline is_kwarg(::Symbol) = false
16+
@inline is_kwarg(e::Expr) = (e.head == :parameters)
1217

13-
function connector_macro(mod, name, body)
18+
function connector_macro(mod, name, body; arglist = Set([]), kwargs = Set([]))
1419
if !Meta.isexpr(body, :block)
1520
err = """
1621
connector body must be a block! It should be in the form of
@@ -23,16 +28,18 @@ function connector_macro(mod, name, body)
2328
"""
2429
error(err)
2530
end
26-
vs = Num[]
31+
vs = []
2732
icon = Ref{Union{String, URI}}()
2833
dict = Dict{Symbol, Any}()
34+
dict[:kwargs] = Dict{Symbol, Any}()
35+
expr = Expr(:block)
2936
for arg in body.args
3037
arg isa LineNumberNode && continue
3138
if arg.head == :macrocall && arg.args[1] == Symbol("@icon")
3239
parse_icon!(icon, dict, dict, arg.args[end])
3340
continue
3441
end
35-
push!(vs, Num(parse_variable_def!(dict, mod, arg, :variables)))
42+
parse_variable_arg!(expr, vs, dict, mod, arg, :variables, kwargs)
3643
end
3744
iv = get(dict, :independent_variable, nothing)
3845
if iv === nothing
@@ -41,31 +48,50 @@ function connector_macro(mod, name, body)
4148
gui_metadata = isassigned(icon) ? GUIMetadata(GlobalRef(mod, name), icon[]) :
4249
nothing
4350
quote
44-
$name = $Model((; name) -> begin
45-
var"#___sys___" = $ODESystem($(Equation[]), $iv, $vs, $([]);
51+
$name = $Model(($(arglist...); name, $(kwargs...)) -> begin
52+
$expr
53+
var"#___sys___" = $ODESystem($(Equation[]), $iv, [$(vs...)], $([]);
4654
name, gui_metadata = $gui_metadata)
4755
$Setfield.@set!(var"#___sys___".connector_type=$connector_type(var"#___sys___"))
4856
end, $dict)
4957
end
5058
end
5159

52-
function parse_variable_def!(dict, mod, arg, varclass)
60+
function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing)
61+
arg isa LineNumberNode && return
5362
MLStyle.@match arg begin
54-
::Symbol => generate_var!(dict, arg, varclass)
55-
Expr(:call, a, b) => generate_var!(dict, a, b, varclass)
63+
a::Symbol => begin
64+
push!(kwargs, Expr(:kw, a, nothing))
65+
var = generate_var!(dict, a, varclass)
66+
dict[:kwargs][getname(var)] = def
67+
(var, nothing)
68+
end
69+
Expr(:call, a, b) => begin
70+
push!(kwargs, Expr(:kw, a, nothing))
71+
var = generate_var!(dict, a, b, varclass)
72+
dict[:kwargs][getname(var)] = def
73+
(var, nothing)
74+
end
5675
Expr(:(=), a, b) => begin
57-
var = parse_variable_def!(dict, mod, a, varclass)
58-
def = parse_default(mod, b)
76+
Base.remove_linenums!(b)
77+
def, meta = parse_default(mod, b)
78+
var, _ = parse_variable_def!(dict, mod, a, varclass, kwargs, def)
5979
dict[varclass][getname(var)][:default] = def
60-
setdefault(var, def)
80+
if !isnothing(meta)
81+
if (ct = get(meta, VariableConnectType, nothing)) !== nothing
82+
dict[varclass][getname(var)][:connection_type] = nameof(ct)
83+
end
84+
var = set_var_metadata(var, meta)
85+
end
86+
(var, def)
6187
end
6288
Expr(:tuple, a, b) => begin
63-
var = parse_variable_def!(dict, mod, a, varclass)
89+
var, def = parse_variable_def!(dict, mod, a, varclass, kwargs)
6490
meta = parse_metadata(mod, b)
6591
if (ct = get(meta, VariableConnectType, nothing)) !== nothing
6692
dict[varclass][getname(var)][:connection_type] = nameof(ct)
6793
end
68-
set_var_metadata(var, meta)
94+
(set_var_metadata(var, meta), def)
6995
end
7096
_ => error("$arg cannot be parsed")
7197
end
@@ -78,14 +104,17 @@ function generate_var(a, varclass)
78104
end
79105
var
80106
end
107+
81108
function generate_var!(dict, a, varclass)
109+
#var = generate_var(Symbol("#", a), varclass)
82110
var = generate_var(a, varclass)
83111
vd = get!(dict, varclass) do
84112
Dict{Symbol, Dict{Symbol, Any}}()
85113
end
86114
vd[a] = Dict{Symbol, Any}()
87115
var
88116
end
117+
89118
function generate_var!(dict, a, b, varclass)
90119
iv = generate_var(b, :variables)
91120
prev_iv = get!(dict, :independent_variable) do
@@ -102,77 +131,101 @@ function generate_var!(dict, a, b, varclass)
102131
end
103132
var
104133
end
134+
105135
function parse_default(mod, a)
106136
a = Base.remove_linenums!(deepcopy(a))
107137
MLStyle.@match a begin
108-
Expr(:block, a) => get_var(mod, a)
109-
::Symbol => get_var(mod, a)
110-
::Number => a
138+
Expr(:block, x) => parse_default(mod, x)
139+
Expr(:tuple, x, y) => begin
140+
def, _ = parse_default(mod, x)
141+
meta = parse_metadata(mod, y)
142+
(def, meta)
143+
end
144+
::Symbol || ::Number => (a, nothing)
145+
Expr(:call, a...) => begin
146+
def = parse_default.(Ref(mod), a)
147+
expr = Expr(:call)
148+
for (d, _) in def
149+
push!(expr.args, d)
150+
end
151+
(expr, nothing)
152+
end
111153
_ => error("Cannot parse default $a")
112154
end
113155
end
156+
114157
function parse_metadata(mod, a)
115158
MLStyle.@match a begin
116159
Expr(:vect, eles...) => Dict(parse_metadata(mod, e) for e in eles)
117160
Expr(:(=), a, b) => Symbolics.option_to_metadata_type(Val(a)) => get_var(mod, b)
118161
_ => error("Cannot parse metadata $a")
119162
end
120163
end
164+
121165
function set_var_metadata(a, ms)
122166
for (m, v) in ms
123167
a = setmetadata(a, m, v)
124168
end
125169
a
126170
end
171+
127172
function get_var(mod::Module, b)
128173
b isa Symbol ? getproperty(mod, b) : b
129174
end
130175

131-
macro model(name::Symbol, expr)
132-
esc(model_macro(__module__, name, expr))
133-
end
134-
135-
function model_macro(mod, name, expr)
176+
function model_macro(mod, name, expr; arglist = Set([]), kwargs = Set([]))
136177
exprs = Expr(:block)
137178
dict = Dict{Symbol, Any}()
179+
dict[:kwargs] = Dict{Symbol, Any}()
138180
comps = Symbol[]
139181
ext = Ref{Any}(nothing)
140-
vs = Symbol[]
141-
ps = Symbol[]
142182
eqs = Expr[]
143183
icon = Ref{Union{String, URI}}()
184+
vs = []
185+
ps = []
186+
144187
for arg in expr.args
145188
arg isa LineNumberNode && continue
146-
arg.head == :macrocall || error("$arg is not valid syntax. Expected a macro call.")
147-
parse_model!(exprs.args, comps, ext, eqs, vs, ps, icon, dict, mod, arg)
189+
if arg.head == :macrocall
190+
parse_model!(exprs.args, comps, ext, eqs, icon, vs, ps,
191+
dict, mod, arg, kwargs)
192+
elseif arg.head == :block
193+
push!(exprs.args, arg)
194+
else
195+
error("$arg is not valid syntax. Expected a macro call.")
196+
end
148197
end
149198
iv = get(dict, :independent_variable, nothing)
150199
if iv === nothing
151200
iv = dict[:independent_variable] = variable(:t)
152201
end
202+
153203
gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) :
154204
nothing
205+
155206
sys = :($ODESystem($Equation[$(eqs...)], $iv, [$(vs...)], [$(ps...)];
156-
systems = [$(comps...)], name, gui_metadata = $gui_metadata))
207+
systems = [$(comps...)], name, gui_metadata = $gui_metadata)) #, defaults = $defaults))
157208
if ext[] === nothing
158209
push!(exprs.args, sys)
159210
else
160211
push!(exprs.args, :($extend($sys, $(ext[]))))
161212
end
162-
:($name = $Model((; name) -> $exprs, $dict))
213+
214+
:($name = $Model(($(arglist...); name, $(kwargs...)) -> $exprs, $dict))
163215
end
164216

165-
function parse_model!(exprs, comps, ext, eqs, vs, ps, icon, dict, mod, arg)
217+
function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, dict,
218+
mod, arg, kwargs)
166219
mname = arg.args[1]
167220
body = arg.args[end]
168221
if mname == Symbol("@components")
169-
parse_components!(exprs, comps, dict, body)
222+
parse_components!(exprs, comps, dict, body, kwargs)
170223
elseif mname == Symbol("@extend")
171224
parse_extend!(exprs, ext, dict, body)
172225
elseif mname == Symbol("@variables")
173-
parse_variables!(exprs, vs, dict, mod, body, :variables)
226+
parse_variables!(exprs, vs, dict, mod, body, :variables, kwargs)
174227
elseif mname == Symbol("@parameters")
175-
parse_variables!(exprs, ps, dict, mod, body, :parameters)
228+
parse_variables!(exprs, ps, dict, mod, body, :parameters, kwargs)
176229
elseif mname == Symbol("@equations")
177230
parse_equations!(exprs, eqs, dict, body)
178231
elseif mname == Symbol("@icon")
@@ -182,7 +235,7 @@ function parse_model!(exprs, comps, ext, eqs, vs, ps, icon, dict, mod, arg)
182235
end
183236
end
184237

185-
function parse_components!(exprs, cs, dict, body)
238+
function parse_components!(exprs, cs, dict, body, kwargs)
186239
expr = Expr(:block)
187240
push!(exprs, expr)
188241
comps = Vector{String}[]
@@ -194,6 +247,9 @@ function parse_components!(exprs, cs, dict, body)
194247
push!(comps, [String(a), String(b.args[1])])
195248
arg = deepcopy(arg)
196249
b = deepcopy(arg.args[2])
250+
251+
component_args!(a, b, expr, kwargs)
252+
197253
push!(b.args, Expr(:kw, :name, Meta.quot(a)))
198254
arg.args[2] = b
199255
push!(expr.args, arg)
@@ -204,6 +260,46 @@ function parse_components!(exprs, cs, dict, body)
204260
dict[:components] = comps
205261
end
206262

263+
function _rename(compname, varname)
264+
compname = Symbol(compname, :__, varname)
265+
end
266+
267+
function component_args!(a, b, expr, kwargs)
268+
# Whenever `b` is a function call, skip the first arg aka the function name.
269+
# Whenver it is a kwargs list, include it.
270+
start = b.head == :call ? 2 : 1
271+
for i in start:lastindex(b.args)
272+
arg = b.args[i]
273+
arg isa LineNumberNode && continue
274+
MLStyle.@match arg begin
275+
::Symbol => begin
276+
_v = _rename(a, arg)
277+
push!(kwargs, _v)
278+
b.args[i] = Expr(:kw, arg, _v)
279+
end
280+
Expr(:parameters, x...) => begin
281+
component_args!(a, arg, expr, kwargs)
282+
end
283+
Expr(:kw, x) => begin
284+
_v = _rename(a, x)
285+
b.args[i] = Expr(:kw, x, _v)
286+
push!(kwargs, _v)
287+
end
288+
Expr(:kw, x, y::Number) => begin
289+
_v = _rename(a, x)
290+
b.args[i] = Expr(:kw, x, _v)
291+
push!(kwargs, Expr(:kw, _v, y))
292+
end
293+
Expr(:kw, x, y) => begin
294+
_v = _rename(a, x)
295+
push!(expr.args, :($y = $_v))
296+
push!(kwargs, Expr(:kw, _v, y))
297+
end
298+
_ => error("Could not parse $arg of component $a")
299+
end
300+
end
301+
end
302+
207303
function parse_extend!(exprs, ext, dict, body)
208304
expr = Expr(:block)
209305
push!(exprs, expr)
@@ -231,16 +327,21 @@ function parse_extend!(exprs, ext, dict, body)
231327
end
232328
end
233329

234-
function parse_variables!(exprs, vs, dict, mod, body, varclass)
330+
function parse_variable_arg!(expr, vs, dict, mod, arg, varclass, kwargs)
331+
vv, def = parse_variable_def!(dict, mod, arg, varclass, kwargs)
332+
v = Num(vv)
333+
name = getname(v)
334+
push!(vs, name)
335+
push!(expr.args,
336+
:($name = $name === nothing ? $setdefault($vv, $def) : $setdefault($vv, $name)))
337+
end
338+
339+
function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
235340
expr = Expr(:block)
236341
push!(exprs, expr)
237342
for arg in body.args
238343
arg isa LineNumberNode && continue
239-
vv = parse_variable_def!(dict, mod, arg, varclass)
240-
v = Num(vv)
241-
name = getname(v)
242-
push!(vs, name)
243-
push!(expr.args, :($name = $v))
344+
parse_variable_arg!(expr, vs, dict, mod, arg, varclass, kwargs)
244345
end
245346
end
246347

0 commit comments

Comments
 (0)