Skip to content

Commit 2969303

Browse files
committed
Add extend parsing
1 parent ff71696 commit 2969303

File tree

1 file changed

+47
-8
lines changed

1 file changed

+47
-8
lines changed

src/systems/connectors.jl

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,11 @@ function connector_macro(mod, name, body)
4545
error("$name doesn't have a independent variable")
4646
end
4747
quote
48-
$name = $Model((; name) -> $ODESystem($(Equation[]), $iv, $vs, $([]); name), $dict)
48+
$name = $Model((; name) -> begin
49+
var"#___sys___" = $ODESystem($(Equation[]), $iv, $vs, $([]);
50+
name)
51+
$Setfield.@set!(var"#___sys___".connector_type=$connector_type(var"#___sys___"))
52+
end, $dict)
4953
end
5054
end
5155

@@ -79,7 +83,7 @@ function generate_var(a, varclass)
7983
var
8084
end
8185
function generate_var!(dict, a, varclass)
82-
var = generate_var(a, :variables)
86+
var = generate_var(a, varclass)
8387
vd = get!(dict, varclass) do
8488
Dict{Symbol, Dict{Symbol, Any}}()
8589
end
@@ -116,6 +120,7 @@ function parse_default(mod, a)
116120
MLStyle.@match a begin
117121
Expr(:block, a) => get_var(mod, a)
118122
::Symbol => get_var(mod, a)
123+
::Number => a
119124
_ => error("Cannot parse default $a")
120125
end
121126
end
@@ -142,28 +147,35 @@ function model_macro(mod, name, expr)
142147
exprs = Expr(:block)
143148
dict = Dict{Symbol, Any}()
144149
comps = Symbol[]
150+
ext = Ref{Any}(nothing)
145151
vs = Symbol[]
146152
ps = Symbol[]
147153
eqs = Expr[]
148154
for arg in expr.args
149155
arg isa LineNumberNode && continue
150156
arg.head == :macrocall || error("$arg is not valid syntax. Expected a macro call.")
151-
parse_model!(exprs.args, comps, eqs, vs, ps, dict, mod, arg)
157+
parse_model!(exprs.args, comps, ext, eqs, vs, ps, dict, mod, arg)
152158
end
153159
iv = get(dict, :independent_variable, nothing)
154160
if iv === nothing
155161
iv = dict[:independent_variable] = variable(:t)
156162
end
157-
push!(exprs.args,
158-
:($ODESystem($Equation[$(eqs...)], $iv, [$(vs...)], [$(ps...)];
159-
systems = [$(comps...)], name)))
163+
sys = :($ODESystem($Equation[$(eqs...)], $iv, [$(vs...)], [$(ps...)];
164+
systems = [$(comps...)], name))
165+
if ext[] === nothing
166+
push!(exprs.args, sys)
167+
else
168+
push!(exprs.args, :($extend($sys, $(ext[]))))
169+
end
160170
:($name = $Model((; name) -> $exprs, $dict))
161171
end
162-
function parse_model!(exprs, comps, eqs, vs, ps, dict, mod, arg)
172+
function parse_model!(exprs, comps, ext, eqs, vs, ps, dict, mod, arg)
163173
mname = arg.args[1]
164174
body = arg.args[end]
165175
if mname == Symbol("@components")
166176
parse_components!(exprs, comps, dict, body)
177+
elseif mname == Symbol("@extend")
178+
parse_extend!(exprs, ext, dict, body)
167179
elseif mname == Symbol("@variables")
168180
parse_variables!(exprs, vs, dict, mod, body, :variables)
169181
elseif mname == Symbol("@parameters")
@@ -195,12 +207,39 @@ function parse_components!(exprs, cs, dict, body)
195207
end
196208
dict[:components] = comps
197209
end
210+
function parse_extend!(exprs, ext, dict, body)
211+
expr = Expr(:block)
212+
push!(exprs, expr)
213+
body = deepcopy(body)
214+
MLStyle.@match body begin
215+
Expr(:(=), a, b) => begin
216+
vars = nothing
217+
if Meta.isexpr(b, :(=))
218+
vars = a
219+
if !Meta.isexpr(vars, :tuple)
220+
error("`@extend` destructuring only takes an tuple as LHS. Got $body")
221+
end
222+
a, b = b.args
223+
vars, a, b
224+
end
225+
ext[] = a
226+
push!(b.args, Expr(:kw, :name, Meta.quot(a)))
227+
dict[:extend] = Symbol.(vars.args), a, readable_code(b)
228+
push!(expr.args, :($a = $b))
229+
if vars !== nothing
230+
push!(expr.args, :(@unpack $vars = $a))
231+
end
232+
end
233+
_ => error("`@extend` only takes an assignment expression. Got $body")
234+
end
235+
end
198236
function parse_variables!(exprs, vs, dict, mod, body, varclass)
199237
expr = Expr(:block)
200238
push!(exprs, expr)
201239
for arg in body.args
202240
arg isa LineNumberNode && continue
203-
v = Num(parse_variable_def!(dict, mod, arg, varclass))
241+
vv = parse_variable_def!(dict, mod, arg, varclass)
242+
v = Num(vv)
204243
name = getname(v)
205244
push!(vs, name)
206245
push!(expr.args, :($name = $v))

0 commit comments

Comments
 (0)