Skip to content

Commit 30dcc02

Browse files
committed
feat: add support for conditional parameters and variables
1 parent a86d257 commit 30dcc02

File tree

2 files changed

+195
-120
lines changed

2 files changed

+195
-120
lines changed

src/systems/model_parsing.jl

Lines changed: 125 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,17 @@ function _model_macro(mod, name, expr, isconnector)
3737
exprs = Expr(:block)
3838
dict = Dict{Symbol, Any}()
3939
dict[:kwargs] = Dict{Symbol, Any}()
40+
dict[:parameters] = Any[Dict{Symbol, Dict{Symbol, Any}}()]
41+
dict[:variables] = Any[Dict{Symbol, Dict{Symbol, Any}}()]
4042
comps = Symbol[]
4143
ext = Ref{Any}(nothing)
4244
eqs = Expr[]
4345
icon = Ref{Union{String, URI}}()
4446
ps, sps, vs, = [], [], []
4547
kwargs = Set()
4648

49+
push!(exprs.args, :(variables = []))
50+
push!(exprs.args, :(parameters = []))
4751
push!(exprs.args, :(systems = ODESystem[]))
4852
push!(exprs.args, :(equations = Equation[]))
4953

@@ -57,47 +61,19 @@ function _model_macro(mod, name, expr, isconnector)
5761
elseif arg.head == :if
5862
MLStyle.@match arg begin
5963
Expr(:if, condition, x) => begin
60-
component_blk, equations_blk, parameter_blk, variable_blk = parse_top_level_branch(condition,
61-
x.args)
62-
63-
component_blk !== nothing &&
64-
parse_components!(exprs.args,
65-
comps,
66-
dict,
67-
:(begin
68-
$component_blk
69-
end),
70-
kwargs)
71-
equations_blk !== nothing &&
72-
parse_equations!(exprs.args, eqs, dict, :(begin
73-
$equations_blk
74-
end))
75-
# parameter_blk !== nothing && parse_variables!(exprs.args, ps, dict, mod, :(begin $parameter_blk end), :parameters, kwargs)
76-
# variable_blk !== nothing && parse_variables!(exprs.args, ps, dict, mod, :(begin $variable_blk end), :variables, kwargs)
64+
parse_conditional_model_statements(comps, dict, eqs, exprs, kwargs,
65+
mod, ps, vs, parse_top_level_branch(condition, x.args)...)
7766
end
7867
Expr(:if, condition, x, y) => begin
79-
component_blk, equations_blk, parameter_blk, variable_blk = parse_top_level_branch(condition,
80-
x.args,
81-
y)
82-
83-
component_blk !== nothing &&
84-
parse_components!(exprs.args,
85-
comps, dict, :(begin
86-
$component_blk
87-
end), kwargs)
88-
equations_blk !== nothing &&
89-
parse_equations!(exprs.args, eqs, dict, :(begin
90-
$equations_blk
91-
end))
92-
# parameter_blk !== nothing && parse_variables!(exprs.args, ps, dict, mod, :(begin $parameter_blk end), :parameters, kwargs)
93-
# variable_blk !== nothing && parse_variables!(exprs.args, ps, dict, mod, :(begin $variable_blk end), :variables, kwargs)
68+
parse_conditional_model_statements(comps, dict, eqs, exprs, kwargs,
69+
mod, ps, vs, parse_top_level_branch(condition, x.args, y)...)
9470
end
9571
_ => error("Got an invalid argument: $arg")
9672
end
9773
elseif isconnector
9874
# Connectors can have variables listed without `@variables` prefix or
9975
# begin block.
100-
parse_variable_arg!(exprs, vs, dict, mod, arg, :variables, kwargs)
76+
parse_variable_arg!(exprs.args, vs, dict, mod, arg, :variables, kwargs)
10177
else
10278
error("$arg is not valid syntax. Expected a macro call.")
10379
end
@@ -108,13 +84,15 @@ function _model_macro(mod, name, expr, isconnector)
10884
iv = dict[:independent_variable] = variable(:t)
10985
end
11086

111-
push!(exprs.args, :(push!(systems, $(comps...))))
11287
push!(exprs.args, :(push!(equations, $(eqs...))))
88+
push!(exprs.args, :(push!(parameters, $(ps...))))
89+
push!(exprs.args, :(push!(systems, $(comps...))))
90+
push!(exprs.args, :(push!(variables, $(vs...))))
11391

11492
gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) :
11593
GUIMetadata(GlobalRef(mod, name))
11694

117-
sys = :($ODESystem($Equation[equations...], $iv, [$(vs...)], [$(ps...)];
95+
sys = :($ODESystem($Equation[equations...], $iv, variables, parameters;
11896
name, systems, gui_metadata = $gui_metadata))
11997

12098
if ext[] === nothing
@@ -131,7 +109,7 @@ function _model_macro(mod, name, expr, isconnector)
131109
end
132110

133111
function parse_variable_def!(dict, mod, arg, varclass, kwargs;
134-
def = nothing, indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
112+
def = nothing, indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
135113
metatypes = [(:connection_type, VariableConnectType),
136114
(:description, VariableDescription),
137115
(:unit, VariableUnit),
@@ -166,12 +144,12 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs;
166144
Base.remove_linenums!(b)
167145
def, meta = parse_default(mod, b)
168146
var, def = parse_variable_def!(dict, mod, a, varclass, kwargs; def)
169-
dict[varclass][getname(var)][:default] = def
147+
dict[varclass][1][getname(var)][:default] = def
170148
if meta !== nothing
171149
for (type, key) in metatypes
172150
if (mt = get(meta, key, nothing)) !== nothing
173151
key == VariableConnectType && (mt = nameof(mt))
174-
dict[varclass][getname(var)][type] = mt
152+
dict[varclass][1][getname(var)][type] = mt
175153
end
176154
end
177155
var = set_var_metadata(var, meta)
@@ -185,7 +163,7 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs;
185163
for (type, key) in metatypes
186164
if (mt = get(meta, key, nothing)) !== nothing
187165
key == VariableConnectType && (mt = nameof(mt))
188-
dict[varclass][getname(var)][type] = mt
166+
dict[varclass][1][getname(var)][type] = mt
189167
end
190168
end
191169
var = set_var_metadata(var, meta)
@@ -196,12 +174,18 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs;
196174
parse_variable_def!(dict, mod, a, varclass, kwargs;
197175
def, indices = [eval.(b)...])
198176
end
177+
#= Expr(:if, condition, a) => begin
178+
var, def = [], []
179+
for var_def in a.args
180+
parse_variable_def!(dict, mod, var_def, varclass, kwargs)
181+
end
182+
end =#
199183
_ => error("$arg cannot be parsed")
200184
end
201185
end
202186

203187
function generate_var(a, varclass;
204-
indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
188+
indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
205189
var = indices === nothing ? Symbolics.variable(a) : first(@variables $a[indices...])
206190
if varclass == :parameters
207191
var = toparam(var)
@@ -210,25 +194,21 @@ function generate_var(a, varclass;
210194
end
211195

212196
function generate_var!(dict, a, varclass;
213-
indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
214-
vd = get!(dict, varclass) do
215-
Dict{Symbol, Dict{Symbol, Any}}()
216-
end
197+
indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
198+
vd = first(dict[varclass])
217199
vd[a] = Dict{Symbol, Any}()
218200
indices !== nothing && (vd[a][:size] = Tuple(lastindex.(indices)))
219201
generate_var(a, varclass; indices)
220202
end
221203

222204
function generate_var!(dict, a, b, varclass;
223-
indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
205+
indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
224206
iv = generate_var(b, :variables)
225207
prev_iv = get!(dict, :independent_variable) do
226208
iv
227209
end
228210
@assert isequal(iv, prev_iv)
229-
vd = get!(dict, varclass) do
230-
Dict{Symbol, Dict{Symbol, Any}}()
231-
end
211+
vd = first(dict[varclass])
232212
vd[a] = Dict{Symbol, Any}()
233213
var = if indices === nothing
234214
Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Real}, Real})(iv)
@@ -291,7 +271,7 @@ function get_var(mod::Module, b)
291271
end
292272

293273
function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps,
294-
dict, mod, arg, kwargs)
274+
dict, mod, arg, kwargs)
295275
mname = arg.args[1]
296276
body = arg.args[end]
297277
if mname == Symbol("@components")
@@ -430,25 +410,87 @@ function parse_extend!(exprs, ext, dict, mod, body, kwargs)
430410
return nothing
431411
end
432412

433-
function parse_variable_arg!(expr, vs, dict, mod, arg, varclass, kwargs)
413+
function parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs)
414+
name, ex = parse_variable_arg(dict, mod, arg, varclass, kwargs)
415+
push!(vs, name)
416+
push!(exprs, ex)
417+
end
418+
419+
function parse_variable_arg(dict, mod, arg, varclass, kwargs)
434420
vv, def = parse_variable_def!(dict, mod, arg, varclass, kwargs)
435421
name = getname(vv)
436-
push!(expr.args,
437-
:($name = $name === nothing ?
438-
$setdefault($vv, $def) :
439-
$setdefault($vv, $name)))
440-
vv isa Num ? push!(vs, name) : push!(vs, :($name...))
422+
return vv isa Num ? name : :($name...),
423+
:($name = $name === nothing ? $setdefault($vv, $def) : $setdefault($vv, $name))
424+
end
425+
426+
function handle_conditional_vars!(arg, conditional_branch, mod, varclass, kwargs)
427+
conditional_dict = Dict(:kwargs => Dict(),
428+
:parameters => Any[Dict{Symbol, Dict{Symbol, Any}}()],
429+
:variables => Any[Dict{Symbol, Dict{Symbol, Any}}()])
430+
for _arg in arg.args
431+
name, ex = parse_variable_arg(conditional_dict, mod, _arg, varclass, kwargs)
432+
push!(conditional_branch.args, ex)
433+
push!(conditional_branch.args, :(push!($varclass, $name)))
434+
end
435+
conditional_dict
441436
end
442437

443438
function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
444439
expr = Expr(:block)
445440
push!(exprs, expr)
446441
for arg in body.args
447442
arg isa LineNumberNode && continue
448-
parse_variable_arg!(expr, vs, dict, mod, arg, varclass, kwargs)
443+
MLStyle.@match arg begin
444+
Expr(:if, condition, x) => begin
445+
conditional_expr = Expr(:if, condition, Expr(:block))
446+
conditional_dict = handle_conditional_vars!(x,
447+
conditional_expr.args[2],
448+
mod,
449+
varclass,
450+
kwargs)
451+
push!(expr.args, conditional_expr)
452+
push!(dict[varclass], (:if, condition, conditional_dict, nothing))
453+
end
454+
Expr(:if, condition, x, y) => begin
455+
conditional_expr = Expr(:if, condition, Expr(:block))
456+
conditional_dict = handle_conditional_vars!(x,
457+
conditional_expr.args[2],
458+
mod,
459+
varclass,
460+
kwargs)
461+
conditional_y_expr, conditional_y_dict = handle_y_vars(y,
462+
conditional_dict,
463+
mod,
464+
varclass,
465+
kwargs)
466+
push!(conditional_expr.args, conditional_y_expr)
467+
push!(expr.args, conditional_expr)
468+
push!(dict[varclass],
469+
(:if, condition, conditional_dict, conditional_y_dict))
470+
end
471+
_ => parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs)
472+
end
449473
end
450474
end
451475

476+
function handle_y_vars(y, dict, mod, varclass, kwargs)
477+
conditional_dict = if Meta.isexpr(y, :elseif)
478+
conditional_y_expr = Expr(:elseif, y.args[1], Expr(:block))
479+
conditional_dict = handle_conditional_vars!(y.args[2],
480+
conditional_y_expr.args[2],
481+
mod,
482+
varclass,
483+
kwargs)
484+
_y_expr, _conditional_dict = handle_y_vars(y.args[end], dict, mod, varclass, kwargs)
485+
push!(conditional_y_expr.args, _y_expr)
486+
(:elseif, y.args[1], conditional_dict, _conditional_dict)
487+
else
488+
conditional_y_expr = Expr(:block)
489+
handle_conditional_vars!(y, conditional_y_expr, mod, varclass, kwargs)
490+
end
491+
conditional_y_expr, conditional_dict
492+
end
493+
452494
function handle_if_x_equations!(ifexpr, condition, x)
453495
push!(ifexpr.args, condition, :(push!(equations, $(x.args...))))
454496
# push!(dict[:equations], [:if, readable_code(condition), readable_code.(x.args)])
@@ -561,7 +603,7 @@ function _parse_components!(exprs, body, kwargs)
561603
arg isa LineNumberNode && continue
562604
MLStyle.@match arg begin
563605
Expr(:block) => begin
564-
# TODO: Do we need this?
606+
# TODO: Do we need this?
565607
error("Multiple `@components` block detected within a single block")
566608
end
567609
Expr(:(=), a, b) => begin
@@ -570,6 +612,7 @@ function _parse_components!(exprs, body, kwargs)
570612

571613
component_args!(a, b, expr, varexpr, kwargs)
572614

615+
arg.args[2] = b
573616
push!(expr.args, arg)
574617
push!(comp_names, a)
575618
push!(comps, [a, b.args[1]])
@@ -645,7 +688,7 @@ function parse_components!(exprs, cs, dict, compbody, kwargs)
645688
$(expr_vec.args...)
646689
end))
647690
end
648-
_ => @info "410 Couldn't parse the component body $compbody" @__LINE__
691+
_ => error("Couldn't parse the component body $compbody")
649692
end
650693
end
651694
end
@@ -707,3 +750,27 @@ function parse_top_level_branch(condition, x, y = nothing, branch = :if)
707750

708751
return blocks
709752
end
753+
754+
function parse_conditional_model_statements(comps, dict, eqs, exprs, kwargs, mod, ps, vs,
755+
component_blk, equations_blk, parameter_blk, variable_blk)
756+
parameter_blk !== nothing &&
757+
parse_variables!(exprs.args, ps, dict, mod, :(begin
758+
$parameter_blk
759+
end), :parameters, kwargs)
760+
761+
variable_blk !== nothing &&
762+
parse_variables!(exprs.args, vs, dict, mod, :(begin
763+
$variable_blk
764+
end), :variables, kwargs)
765+
766+
component_blk !== nothing &&
767+
parse_components!(exprs.args,
768+
comps, dict, :(begin
769+
$component_blk
770+
end), kwargs)
771+
772+
equations_blk !== nothing &&
773+
parse_equations!(exprs.args, eqs, dict, :(begin
774+
$equations_blk
775+
end))
776+
end

0 commit comments

Comments
 (0)