Skip to content

Commit afae563

Browse files
committed
feat: conditional components and equations at the top level of @mtkmodel
1 parent 4246a7c commit afae563

File tree

1 file changed

+140
-80
lines changed

1 file changed

+140
-80
lines changed

src/systems/model_parsing.jl

Lines changed: 140 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@ struct Model{F, S}
1616
(:structural_parameters) and equations (:equations).
1717
"""
1818
structure::S
19-
"""
20-
This flag is `true` when the Model is a connector and is `false` when it is
21-
a component
2219
"""
20+
This flag is `true` when the Model is a connector and is `false` when it is
21+
a component
22+
"""
2323
isconnector::Bool
2424
end
2525
(m::Model)(args...; kw...) = m.f(args...; kw...)
2626

2727
for f in (:connector, :mtkmodel)
28-
isconnector = f == :connector ? true : false
28+
isconnector = f == :connector ? true : false
2929
@eval begin
3030
macro $f(name::Symbol, body)
3131
esc($(:_model_macro)(__module__, name, body, $isconnector))
@@ -47,18 +47,58 @@ function _model_macro(mod, name, expr, isconnector)
4747
push!(exprs.args, :(systems = ODESystem[]))
4848
push!(exprs.args, :(equations = Equation[]))
4949

50+
Base.remove_linenums!(expr)
5051
for arg in expr.args
51-
arg isa LineNumberNode && continue
5252
if arg.head == :macrocall
5353
parse_model!(exprs.args, comps, ext, eqs, icon, vs, ps,
5454
sps, dict, mod, arg, kwargs)
5555
elseif arg.head == :block
5656
push!(exprs.args, arg)
57+
elseif arg.head == :if
58+
MLStyle.@match arg begin
59+
Expr(:if, condition, x) => begin
60+
component_blk, equations_blk, parameter_blk, variable_blk = parse_top_level_branch(condition,
61+
x.args)
62+
63+
component_blk !== nothing &&
64+
parse_components!(exprs.args,
65+
comps,
66+
dict,
67+
:(begin
68+
$component_blk
69+
end),
70+
kwargs)
71+
equations_blk !== nothing &&
72+
parse_equations!(exprs.args, eqs, dict, :(begin
73+
$equations_blk
74+
end))
75+
# parameter_blk !== nothing && parse_variables!(exprs.args, ps, dict, mod, :(begin $parameter_blk end), :parameters, kwargs)
76+
# variable_blk !== nothing && parse_variables!(exprs.args, ps, dict, mod, :(begin $variable_blk end), :variables, kwargs)
77+
end
78+
Expr(:if, condition, x, y) => begin
79+
component_blk, equations_blk, parameter_blk, variable_blk = parse_top_level_branch(condition,
80+
x.args,
81+
y)
82+
83+
component_blk !== nothing &&
84+
parse_components!(exprs.args,
85+
comps, dict, :(begin
86+
$component_blk
87+
end), kwargs)
88+
equations_blk !== nothing &&
89+
parse_equations!(exprs.args, eqs, dict, :(begin
90+
$equations_blk
91+
end))
92+
# parameter_blk !== nothing && parse_variables!(exprs.args, ps, dict, mod, :(begin $parameter_blk end), :parameters, kwargs)
93+
# variable_blk !== nothing && parse_variables!(exprs.args, ps, dict, mod, :(begin $variable_blk end), :variables, kwargs)
94+
end
95+
_ => error("Got an invalid argument: $arg")
96+
end
5797
elseif isconnector
5898
# Connectors can have variables listed without `@variables` prefix or
5999
# begin block.
60100
parse_variable_arg!(exprs, vs, dict, mod, arg, :variables, kwargs)
61-
else
101+
else
62102
error("$arg is not valid syntax. Expected a macro call.")
63103
end
64104
end
@@ -75,7 +115,7 @@ function _model_macro(mod, name, expr, isconnector)
75115
GUIMetadata(GlobalRef(mod, name))
76116

77117
sys = :($ODESystem($Equation[equations...], $iv, [$(vs...)], [$(ps...)];
78-
name, systems, gui_metadata = $gui_metadata))
118+
name, systems, gui_metadata = $gui_metadata))
79119

80120
if ext[] === nothing
81121
push!(exprs.args, :(var"#___sys___" = $sys))
@@ -220,33 +260,11 @@ function parse_default(mod, a)
220260
end
221261
(expr, nothing)
222262
end
223-
#=Expr(:if, condition::Expr, x, y) => begin
224-
@info 212
225-
if condition.args[1] in (:(==), :(<), :(>))
226-
op = compare_op(condition.args[1])
227-
expr = Expr(:call)
228-
push!(expr.args, op)
229-
for cond in condition.args[2:end]
230-
# cond isa Symbol ? push!(expr.args, :($getdefault($cond))) :
231-
push!(expr.args, cond)
232-
end
233-
a.args[1] = expr
234-
end
235-
(a, nothing)
236-
end=#
237263
Expr(:if, condition, x, y) => (a, nothing)
238264
_ => error("Cannot parse default $a $(typeof(a))")
239265
end
240266
end
241267

242-
compare_op(a) = if a == :(==)
243-
:isequal
244-
elseif a == :(<)
245-
:isless
246-
elseif a == :(>)
247-
:(Base.isgreater)
248-
end
249-
250268
function parse_metadata(mod, a)
251269
MLStyle.@match a begin
252270
Expr(:vect, eles...) => Dict(parse_metadata(mod, e) for e in eles)
@@ -277,7 +295,7 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps,
277295
mname = arg.args[1]
278296
body = arg.args[end]
279297
if mname == Symbol("@components")
280-
parse_components!(mod, exprs, comps, dict, body, kwargs)
298+
parse_components!(exprs, comps, dict, body, kwargs)
281299
elseif mname == Symbol("@extend")
282300
parse_extend!(exprs, ext, dict, mod, body, kwargs)
283301
elseif mname == Symbol("@variables")
@@ -339,14 +357,6 @@ function extend_args!(a, b, dict, expr, kwargs, varexpr, has_param = false)
339357
dict[:kwargs][x] = nothing
340358
end
341359
Expr(:kw, x, y) => begin
342-
#= _v = _rename(a, x)
343-
push!(expr.args, :($_v = $y))
344-
def = Expr(:kw)
345-
push!(def.args, x)
346-
push!(def.args, :($getdefault($_v)))
347-
b.args[i] = def
348-
# b.args[i] = Expr(:kw, x, _v)
349-
push!(kwargs, Expr(:kw, _v, nothing))=#
350360
b.args[i] = Expr(:kw, x, x)
351361
push!(varexpr.args, :($x = $x === nothing ? $y : $x))
352362
push!(kwargs, Expr(:kw, x, nothing))
@@ -439,7 +449,7 @@ function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
439449
end
440450
end
441451

442-
function handle_if_x_equations!(ifexpr, condition, x, dict)
452+
function handle_if_x_equations!(ifexpr, condition, x)
443453
push!(ifexpr.args, condition, :(push!(equations, $(x.args...))))
444454
# push!(dict[:equations], [:if, readable_code(condition), readable_code.(x.args)])
445455
readable_code.(x.args)
@@ -449,8 +459,9 @@ function handle_if_y_equations!(ifexpr, y, dict)
449459
if y.head == :elseif
450460
elseifexpr = Expr(:elseif)
451461
eq_entry = [:elseif, readable_code.(y.args[1].args)...]
452-
push!(eq_entry, handle_if_x_equations!(elseifexpr, y.args[1], y.args[2], dict))
453-
get(y.args, 3, nothing) !== nothing && push!(eq_entry, handle_if_y_equations!(elseifexpr, y.args[3], dict))
462+
push!(eq_entry, handle_if_x_equations!(elseifexpr, y.args[1], y.args[2]))
463+
get(y.args, 3, nothing) !== nothing &&
464+
push!(eq_entry, handle_if_y_equations!(elseifexpr, y.args[3], dict))
454465
push!(ifexpr.args, elseifexpr)
455466
(eq_entry...,)
456467
else
@@ -466,17 +477,16 @@ function parse_equations!(exprs, eqs, dict, body)
466477
MLStyle.@match arg begin
467478
Expr(:if, condition, x) => begin
468479
ifexpr = Expr(:if)
469-
eq_entry = handle_if_x_equations!(ifexpr, condition, x, dict)
480+
eq_entry = handle_if_x_equations!(ifexpr, condition, x)
470481
push!(exprs, ifexpr)
471482
push!(dict[:equations], [:if, condition, eq_entry])
472483
end
473484
Expr(:if, condition, x, y) => begin
474485
ifexpr = Expr(:if)
475-
xeq_entry = handle_if_x_equations!(ifexpr, condition, x, dict)
486+
xeq_entry = handle_if_x_equations!(ifexpr, condition, x)
476487
yeq_entry = handle_if_y_equations!(ifexpr, y, dict)
477488
push!(exprs, ifexpr)
478489
push!(dict[:equations], [:if, condition, xeq_entry, yeq_entry])
479-
# push!(dict[:equations], yeq_entry...)
480490
end
481491
_ => push!(eqs, arg)
482492
end
@@ -523,7 +533,7 @@ function component_args!(a, b, expr, varexpr, kwargs)
523533
b.args[i] = Expr(:kw, x, _v)
524534
push!(varexpr.args, :((@isdefined $x) && ($_v = $x)))
525535
push!(kwargs, Expr(:kw, _v, nothing))
526-
# dict[:kwargs][_v] = nothing
536+
# dict[:kwargs][_v] = nothing
527537
end
528538
Expr(:parameters, x...) => begin
529539
component_args!(a, arg, expr, varexpr, kwargs)
@@ -550,20 +560,21 @@ function _parse_components!(exprs, body, kwargs)
550560
for arg in body.args
551561
arg isa LineNumberNode && continue
552562
MLStyle.@match arg begin
563+
Expr(:block) => begin
564+
# TODO: Do we need this?
565+
error("Multiple `@components` block detected within a single block")
566+
end
553567
Expr(:(=), a, b) => begin
554568
arg = deepcopy(arg)
555569
b = deepcopy(arg.args[2])
556570

557571
component_args!(a, b, expr, varexpr, kwargs)
558572

559-
# push!(b.args, Expr(:kw, :name, Meta.quot(a)))
560-
# arg.args[2] = b
561-
562573
push!(expr.args, arg)
563574
push!(comp_names, a)
564575
push!(comps, [a, b.args[1]])
565576
end
566-
_ => @info "Couldn't parse the component body: $arg"
577+
_ => error("Couldn't parse the component body: $arg")
567578
end
568579
end
569580
return comp_names, comps, expr, varexpr
@@ -572,24 +583,14 @@ end
572583
function push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
573584
blk = Expr(:block)
574585
push!(blk.args, varexpr)
575-
push!(blk.args, :(@named begin $(expr_vec.args...) end))
586+
push!(blk.args, :(@named begin
587+
$(expr_vec.args...)
588+
end))
576589
push!(blk.args, :($push!(systems, $(comp_names...))))
577590
push!(ifexpr.args, blk)
578591
end
579592

580593
function handle_if_x!(mod, exprs, ifexpr, x, kwargs, condition = nothing)
581-
@info 576 condition typeof(condition)
582-
# push!(ifexpr.args, :($substitute_defaults($condition)))
583-
#= if condition isa Symbol
584-
@info 579 condition
585-
push!(ifexpr.args, :($getdefault($condition)))
586-
elseif condition isa Num
587-
push!(ifexpr.args, :($substitute_defaults($condition)))
588-
elseif condition isa Expr
589-
push!(ifexpr.args, morph_with_default!(condition))
590-
else
591-
@info "Don't know what to do with $(typeof(condition))"
592-
end =#
593594
push!(ifexpr.args, condition)
594595
comp_names, comps, expr_vec, varexpr = _parse_components!(ifexpr, x, kwargs)
595596
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
@@ -602,48 +603,107 @@ function handle_if_y!(exprs, ifexpr, y, kwargs)
602603
comps = [:elseif, y.args[1]]
603604
elseifexpr = Expr(:elseif)
604605
push!(comps, handle_if_x!(mod, exprs, elseifexpr, y.args[2], kwargs, y.args[1]))
605-
get(y.args, 3, nothing) !== nothing && push!(comps, handle_if_y!(exprs, elseifexpr, y.args[3], kwargs))
606+
get(y.args, 3, nothing) !== nothing &&
607+
push!(comps, handle_if_y!(exprs, elseifexpr, y.args[3], kwargs))
606608
push!(ifexpr.args, elseifexpr)
607609
(comps...,)
608610
else
609-
comp_names, comps, expr_vec, varexpr, = _parse_components!(exprs, y, kwargs)
611+
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs, y, kwargs)
610612
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
611613
comps
612614
end
613615
end
614616

615-
function parse_components!(mod, exprs, cs, dict, compbody, kwargs)
617+
function handle_conditional_components(condition, dict, exprs, kwargs, x, y = nothing)
618+
ifexpr = Expr(:if)
619+
comps = handle_if_x!(mod, exprs, ifexpr, x, kwargs, condition)
620+
ycomps = y === nothing ? [] : handle_if_y!(exprs, ifexpr, y, kwargs)
621+
push!(exprs, ifexpr)
622+
push!(dict[:components], (:if, condition, comps, ycomps))
623+
end
624+
625+
function parse_components!(exprs, cs, dict, compbody, kwargs)
616626
dict[:components] = []
617627
Base.remove_linenums!(compbody)
618628
for arg in compbody.args
619629
MLStyle.@match arg begin
620630
Expr(:if, condition, x) => begin
621-
ifexpr = Expr(:if)
622-
comps = handle_if_x!(mod, exprs, ifexpr, x, kwargs, condition)
623-
push!(exprs, ifexpr)
624-
push!(dict[:components], (:if, condition, comps, []))
631+
handle_conditional_components(condition, dict, exprs, kwargs, x)
625632
end
626633
Expr(:if, condition, x, y) => begin
627-
ifexpr = Expr(:if)
628-
comps = handle_if_x!(mod, exprs, ifexpr, x, kwargs, condition)
629-
ycomps = handle_if_y!(exprs, ifexpr, y, kwargs)
630-
push!(exprs, ifexpr)
631-
push!(dict[:components], (:if, condition, comps, ycomps))
634+
handle_conditional_components(condition, dict, exprs, kwargs, x, y)
632635
end
633636
Expr(:(=), a, b) => begin
634-
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs, :(begin $arg end), kwargs)
637+
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs,
638+
:(begin
639+
$arg
640+
end),
641+
kwargs)
635642
push!(cs, comp_names...)
636643
push!(dict[:components], comps...)
637-
push!(exprs, varexpr, :(@named begin $(expr_vec.args...) end))
644+
push!(exprs, varexpr, :(@named begin
645+
$(expr_vec.args...)
646+
end))
638647
end
639-
_ => @info "410 Couldn't parse the component body $arg"
648+
_ => @info "410 Couldn't parse the component body $compbody" @__LINE__
640649
end
641650
end
642-
643-
### zzz
644-
# push!(exprs, :(@named $expr))
645651
end
646652

647653
function _rename(compname, varname)
648654
compname = Symbol(compname, :__, varname)
649655
end
656+
657+
# Handle top level branching
658+
push_something!(v, ::Nothing) = v
659+
push_something!(v, x) = push!(v, x)
660+
push_something!(v, x...) = push_something!.(Ref(v), x)
661+
662+
define_blocks(branch) = [Expr(branch), Expr(branch), Expr(branch), Expr(branch)]
663+
664+
function parse_top_level_branch(condition, x, y = nothing, branch = :if)
665+
blocks::Vector{Union{Expr, Nothing}} = component_blk, equations_blk, parameter_blk, variable_blk = define_blocks(branch)
666+
667+
for arg in x
668+
if arg.args[1] == Symbol("@components")
669+
push_something!(component_blk.args, condition, arg.args[end])
670+
elseif arg.args[1] == Symbol("@equations")
671+
push_something!(equations_blk.args, condition, arg.args[end])
672+
elseif arg.args[1] == Symbol("@variables")
673+
push_something!(variable_blk.args, condition, arg.args[end])
674+
elseif arg.args[1] == Symbol("@parameters")
675+
push_something!(parameter_blk.args, condition, arg.args[end])
676+
else
677+
error("$(arg.args[1]) isn't supported")
678+
end
679+
end
680+
681+
if y !== nothing
682+
yblocks = if y.head == :elseif
683+
parse_top_level_branch(y.args[1],
684+
y.args[2].args,
685+
lastindex(y.args) == 3 ? y.args[3] : nothing,
686+
:elseif)
687+
else
688+
yblocks = parse_top_level_branch(nothing, y.args, nothing, :block)
689+
690+
for i in 1:lastindex(yblocks)
691+
yblocks[i] !== nothing && (yblocks[i] = yblocks[i].args[end])
692+
end
693+
yblocks
694+
end
695+
for i in 1:lastindex(yblocks)
696+
if lastindex(blocks[i].args) == 1
697+
push_something!(blocks[i].args, Expr(:block), yblocks[i])
698+
else
699+
push_something!(blocks[i].args, yblocks[i])
700+
end
701+
end
702+
end
703+
704+
for i in 1:lastindex(blocks)
705+
isempty(blocks[i].args) && (blocks[i] = nothing)
706+
end
707+
708+
return blocks
709+
end

0 commit comments

Comments
 (0)