Skip to content

Commit b11a497

Browse files
committed
feat: handle hierarchal kwargs of the Model
`@named model_a = ModelA( model_b.component=1)` just works
1 parent 6ab41e7 commit b11a497

File tree

3 files changed

+87
-8
lines changed

3 files changed

+87
-8
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ import SymbolicUtils.Code: toexpr
4949
import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint
5050
import JuliaFormatter
5151

52+
using MLStyle
53+
5254
using Reexport
5355
using Symbolics: degree
5456
@reexport using Symbolics

src/systems/abstractsystem.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,32 @@ function split_assign(expr)
935935
name, call = expr.args
936936
end
937937

938+
varname_fix!(s) = return
939+
940+
function varname_fix!(expr::Expr)
941+
for arg in expr.args
942+
MLStyle.@match arg begin
943+
::Symbol => continue
944+
Expr(:kw, a) => varname_sanitization!(arg)
945+
Expr(:parameters, a...) => begin
946+
for _arg in arg.args
947+
varname_sanitization!(_arg)
948+
end
949+
end
950+
_ => @debug "skipping variable sanitization of $arg"
951+
end
952+
end
953+
end
954+
955+
varname_sanitization!(a) = return
956+
957+
function varname_sanitization!(expr::Expr)
958+
var_splits = split(string(expr.args[1]), ".")
959+
if length(var_splits) > 1
960+
expr.args[1] = Symbol(join(var_splits, "__"))
961+
end
962+
end
963+
938964
function _named(name, call, runtime = false)
939965
has_kw = false
940966
call isa Expr || throw(Meta.ParseError("The rhs must be an Expr. Got $call."))
@@ -948,6 +974,8 @@ function _named(name, call, runtime = false)
948974
end
949975
end
950976

977+
varname_fix!(call)
978+
951979
if !has_kw
952980
param = Expr(:parameters)
953981
if length(call.args) == 1

src/systems/model_parsing.jl

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ struct Model{F, S}
88
end
99
(m::Model)(args...; kw...) = m.f(args...; kw...)
1010

11-
using MLStyle
12-
1311
function connector_macro(mod, name, body)
1412
if !Meta.isexpr(body, :block)
1513
err = """
@@ -82,10 +80,10 @@ end
8280
# methods)
8381
function parse_variables_with_kw!(exprs, dict, mod, body, varclass, kwargs)
8482
expr = if varclass == :parameters
85-
:(ps = @parameters begin
83+
:(pss = @parameters begin
8684
end)
8785
elseif varclass == :variables
88-
:(vs = @variables begin
86+
:(vss = @variables begin
8987
end)
9088
end
9189

@@ -97,9 +95,12 @@ function parse_variables_with_kw!(exprs, dict, mod, body, varclass, kwargs)
9795
def = Base.remove_linenums!(b).args[end]
9896
push!(expr.args[end].args[end].args, :($a = $def))
9997
push!(kwargs, def)
98+
@info "\nIn $varclass $kwargs for arg: $arg"
10099
end
100+
_ => "got $arg"
101101
end
102102
end
103+
dict[:kwargs] = kwargs
103104
push!(exprs, expr)
104105
end
105106

@@ -173,6 +174,8 @@ function model_macro(mod, name, expr)
173174
eqs = Expr[]
174175
icon = Ref{Union{String, URI}}()
175176
kwargs = []
177+
vs, vss = [], []
178+
ps, pss = [], []
176179
for arg in expr.args
177180
arg isa LineNumberNode && continue
178181
arg.head == :macrocall || error("$arg is not valid syntax. Expected a macro call.")
@@ -184,22 +187,22 @@ function model_macro(mod, name, expr)
184187
end
185188
gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) :
186189
nothing
187-
sys = :($ODESystem($Equation[$(eqs...)], $iv, vs, ps;
190+
sys = :($ODESystem($Equation[$(eqs...)], $iv, [], [$(ps...); pss...];
188191
systems = [$(comps...)], name, gui_metadata = $gui_metadata))
189192
if ext[] === nothing
190193
push!(exprs.args, sys)
191194
else
192195
push!(exprs.args, :($extend($sys, $(ext[]))))
193196
end
194-
197+
@info "\nexprs $exprs final kwargs: $kwargs"
195198
:($name = $Model((; name, $(kwargs...)) -> $exprs, $dict))
196199
end
197200

198201
function parse_model!(exprs, comps, ext, eqs, icon, dict, mod, arg, kwargs)
199202
mname = arg.args[1]
200203
body = arg.args[end]
201204
if mname == Symbol("@components")
202-
parse_components!(exprs, comps, dict, body)
205+
parse_components!(exprs, comps, dict, body, kwargs)
203206
elseif mname == Symbol("@extend")
204207
parse_extend!(exprs, ext, dict, body)
205208
elseif mname == Symbol("@variables")
@@ -215,18 +218,64 @@ function parse_model!(exprs, comps, ext, eqs, icon, dict, mod, arg, kwargs)
215218
end
216219
end
217220

218-
function parse_components!(exprs, cs, dict, body)
221+
function var_rename(compname, varname::Expr, arglist)
222+
@info typeof(varname)
223+
compname = Symbol(compname, :__, varname.args[1])
224+
push!(arglist, Expr(:(=), compname, varname.args[2]))
225+
@info "$(typeof(varname)) | the arglist @220 is $arglist"
226+
return Expr(:kw, varname, compname)
227+
end
228+
229+
function var_rename(compname, varname, arglist)
230+
compname = Symbol(compname, :__, varname)
231+
push!(arglist, :($compname))
232+
@info "$(typeof(varname)) | the arglist @229 is $arglist"
233+
return Expr(:kw, varname, compname)
234+
end
235+
236+
function component_args!(compname, comparg, arglist, varnamed)
237+
for arg in comparg.args
238+
arg isa LineNumberNode && continue
239+
MLStyle.@match arg begin
240+
Expr(:parameters, a, b) => begin
241+
component_args!(compname, arg, arglist, varnamed)
242+
end
243+
Expr(:parameters, Expr) => begin
244+
# push!(varnamed , var_rename(compname, a, arglist))
245+
push!(varnamed , var_rename.(Ref(compname), arg.args, Ref(arglist)))
246+
end
247+
Expr(:parameters, a) => begin
248+
# push!(varnamed , var_rename(compname, a, arglist))
249+
for a_arg in a.args
250+
push!(varnamed , var_rename(compname, a_arg, arglist))
251+
end
252+
end
253+
Expr(:kw, a, b) => begin
254+
push!(varnamed , var_rename(compname, a, arglist))
255+
end
256+
::Symbol => continue
257+
_ => @info "got $arg"
258+
end
259+
end
260+
end
261+
262+
function parse_components!(exprs, cs, dict, body, kwargs)
219263
expr = Expr(:block)
220264
push!(exprs, expr)
221265
comps = Vector{String}[]
266+
varnamed = []
222267
for arg in body.args
223268
arg isa LineNumberNode && continue
224269
MLStyle.@match arg begin
225270
Expr(:(=), a, b) => begin
226271
push!(cs, a)
272+
component_args!(a, b, kwargs, varnamed)
227273
push!(comps, [String(a), String(b.args[1])])
228274
arg = deepcopy(arg)
229275
b = deepcopy(arg.args[2])
276+
277+
b.args[2] = varnamed[1][1]
278+
230279
push!(b.args, Expr(:kw, :name, Meta.quot(a)))
231280
arg.args[2] = b
232281
push!(expr.args, arg)

0 commit comments

Comments
 (0)