Skip to content

Commit 0ae2810

Browse files
committed
fix: make :parameters and :variables in Model.structure backwards compatible
- all variables are added as a key and `:condition` is added as a metadata for the conditional variables. - The `:condition` contains entire if-else block info as a tuple of (condition-branch, condition, variables-if-correct, variable-if-condition-isn't-met). - variable-if-condition-isn't-met is nothing or tuple similar to the one above.
1 parent 16025fd commit 0ae2810

File tree

1 file changed

+97
-25
lines changed

1 file changed

+97
-25
lines changed

src/systems/model_parsing.jl

Lines changed: 97 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ function _model_macro(mod, name, expr, isconnector)
3737
exprs = Expr(:block)
3838
dict = Dict{Symbol, Any}()
3939
dict[:kwargs] = Dict{Symbol, Any}()
40-
dict[:parameters] = Any[Dict{Symbol, Dict{Symbol, Any}}()]
41-
dict[:variables] = Any[Dict{Symbol, Dict{Symbol, Any}}()]
4240
comps = Symbol[]
4341
ext = Ref{Any}(nothing)
4442
eqs = Expr[]
@@ -109,7 +107,7 @@ function _model_macro(mod, name, expr, isconnector)
109107
end
110108

111109
function parse_variable_def!(dict, mod, arg, varclass, kwargs;
112-
def = nothing, indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
110+
def = nothing, indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
113111
metatypes = [(:connection_type, VariableConnectType),
114112
(:description, VariableDescription),
115113
(:unit, VariableUnit),
@@ -144,12 +142,16 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs;
144142
Base.remove_linenums!(b)
145143
def, meta = parse_default(mod, b)
146144
var, def = parse_variable_def!(dict, mod, a, varclass, kwargs; def)
147-
dict[varclass][1][getname(var)][:default] = def
145+
dict[varclass][getname(var)][:default] = def
148146
if meta !== nothing
149147
for (type, key) in metatypes
150148
if (mt = get(meta, key, nothing)) !== nothing
151149
key == VariableConnectType && (mt = nameof(mt))
152-
dict[varclass][1][getname(var)][type] = mt
150+
if dict[varclass] isa Vector
151+
dict[varclass][1][getname(var)][type] = mt
152+
else
153+
dict[varclass][getname(var)][type] = mt
154+
end
153155
end
154156
end
155157
var = set_var_metadata(var, meta)
@@ -163,7 +165,12 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs;
163165
for (type, key) in metatypes
164166
if (mt = get(meta, key, nothing)) !== nothing
165167
key == VariableConnectType && (mt = nameof(mt))
166-
dict[varclass][1][getname(var)][type] = mt
168+
# @info dict 164
169+
if dict[varclass] isa Vector
170+
dict[varclass][1][getname(var)][type] = mt
171+
else
172+
dict[varclass][getname(var)][type] = mt
173+
end
167174
end
168175
end
169176
var = set_var_metadata(var, meta)
@@ -185,7 +192,7 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs;
185192
end
186193

187194
function generate_var(a, varclass;
188-
indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
195+
indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
189196
var = indices === nothing ? Symbolics.variable(a) : first(@variables $a[indices...])
190197
if varclass == :parameters
191198
var = toparam(var)
@@ -194,21 +201,27 @@ function generate_var(a, varclass;
194201
end
195202

196203
function generate_var!(dict, a, varclass;
197-
indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
198-
vd = first(dict[varclass])
204+
indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
205+
vd = get!(dict, varclass) do
206+
Dict{Symbol, Dict{Symbol, Any}}()
207+
end
208+
vd isa Vector && (vd = first(vd))
199209
vd[a] = Dict{Symbol, Any}()
200210
indices !== nothing && (vd[a][:size] = Tuple(lastindex.(indices)))
201211
generate_var(a, varclass; indices)
202212
end
203213

204214
function generate_var!(dict, a, b, varclass;
205-
indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
215+
indices::Union{Vector{UnitRange{Int}}, Nothing} = nothing)
206216
iv = generate_var(b, :variables)
207217
prev_iv = get!(dict, :independent_variable) do
208218
iv
209219
end
210-
@assert isequal(iv, prev_iv)
211-
vd = first(dict[varclass])
220+
@assert isequal(iv, prev_iv) "Multiple independent variables are used in the model"
221+
vd = get!(dict, varclass) do
222+
Dict{Symbol, Dict{Symbol, Any}}()
223+
end
224+
vd isa Vector && (vd = first(vd))
212225
vd[a] = Dict{Symbol, Any}()
213226
var = if indices === nothing
214227
Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Real}, Real})(iv)
@@ -271,7 +284,7 @@ function get_var(mod::Module, b)
271284
end
272285

273286
function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps,
274-
dict, mod, arg, kwargs)
287+
dict, mod, arg, kwargs)
275288
mname = arg.args[1]
276289
body = arg.args[end]
277290
if mname == Symbol("@components")
@@ -435,6 +448,57 @@ function handle_conditional_vars!(arg, conditional_branch, mod, varclass, kwargs
435448
conditional_dict
436449
end
437450

451+
function prune_conditional_dict!(conditional_tuple::Tuple)
452+
prune_conditional_dict!.(collect(conditional_tuple))
453+
end
454+
function prune_conditional_dict!(conditional_dict::Dict)
455+
for k in [:parameters, :variables]
456+
length(conditional_dict[k]) == 1 && isempty(first(conditional_dict[k])) &&
457+
delete!(conditional_dict, k)
458+
end
459+
isempty(conditional_dict[:kwargs]) && delete!(conditional_dict, :kwargs)
460+
end
461+
prune_conditional_dict!(_) = return nothing
462+
463+
function get_conditional_dict!(conditional_dict, conditional_y_tuple::Tuple)
464+
k = get_conditional_dict!.(Ref(conditional_dict), collect(conditional_y_tuple))
465+
push_something!(conditional_dict,
466+
k...)
467+
conditional_dict
468+
end
469+
470+
function get_conditional_dict!(conditional_dict::Dict, conditional_y_tuple::Dict)
471+
merge!(conditional_dict[:kwargs], conditional_y_tuple[:kwargs])
472+
for key in [:parameters, :variables]
473+
merge!(conditional_dict[key][1], conditional_y_tuple[key][1])
474+
end
475+
conditional_dict
476+
end
477+
478+
get_conditional_dict!(a, b) = (return nothing)
479+
480+
function push_conditional_dict!(dict, condition, conditional_dict,
481+
conditional_y_tuple, varclass)
482+
vd = get!(dict, varclass) do
483+
Dict{Symbol, Dict{Symbol, Any}}()
484+
end
485+
for k in keys(conditional_dict[varclass][1])
486+
vd[k] = copy(conditional_dict[varclass][1][k])
487+
vd[k][:condition] = (:if, condition, conditional_dict, conditional_y_tuple)
488+
end
489+
conditional_y_dict = Dict(:kwargs => Dict(),
490+
:parameters => Any[Dict{Symbol, Dict{Symbol, Any}}()],
491+
:variables => Any[Dict{Symbol, Dict{Symbol, Any}}()])
492+
get_conditional_dict!(conditional_y_dict, conditional_y_tuple)
493+
494+
prune_conditional_dict!(conditional_y_dict)
495+
prune_conditional_dict!(conditional_dict)
496+
!isempty(conditional_y_dict) && for k in keys(conditional_y_dict[varclass][1])
497+
vd[k] = copy(conditional_y_dict[varclass][1][k])
498+
vd[k][:condition] = (:if, condition, conditional_dict, conditional_y_tuple)
499+
end
500+
end
501+
438502
function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
439503
expr = Expr(:block)
440504
push!(exprs, expr)
@@ -449,7 +513,7 @@ function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
449513
varclass,
450514
kwargs)
451515
push!(expr.args, conditional_expr)
452-
push!(dict[varclass], (:if, condition, conditional_dict, nothing))
516+
push_conditional_dict!(dict, condition, conditional_dict, nothing, varclass)
453517
end
454518
Expr(:if, condition, x, y) => begin
455519
conditional_expr = Expr(:if, condition, Expr(:block))
@@ -458,15 +522,18 @@ function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
458522
mod,
459523
varclass,
460524
kwargs)
461-
conditional_y_expr, conditional_y_dict = handle_y_vars(y,
525+
conditional_y_expr, conditional_y_tuple = handle_y_vars(y,
462526
conditional_dict,
463527
mod,
464528
varclass,
465529
kwargs)
466530
push!(conditional_expr.args, conditional_y_expr)
467531
push!(expr.args, conditional_expr)
468-
push!(dict[varclass],
469-
(:if, condition, conditional_dict, conditional_y_dict))
532+
push_conditional_dict!(dict,
533+
condition,
534+
conditional_dict,
535+
conditional_y_tuple,
536+
varclass)
470537
end
471538
_ => parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs)
472539
end
@@ -492,10 +559,11 @@ function handle_y_vars(y, dict, mod, varclass, kwargs)
492559
end
493560

494561
function handle_if_x_equations!(condition, dict, ifexpr, x)
495-
push!(ifexpr.args, condition, :(push!(equations, $(x.args...))))
496-
# push!(dict[:equations], [:if, readable_code(condition), readable_code.(x.args)])
497-
readable_code.(x.args)
562+
push!(ifexpr.args, condition, :(push!(equations, $(x.args...))))
563+
# push!(dict[:equations], [:if, readable_code(condition), readable_code.(x.args)])
564+
readable_code.(x.args)
498565
end
566+
499567
function handle_if_y_equations!(ifexpr, y, dict)
500568
if y.head == :elseif
501569
elseifexpr = Expr(:elseif)
@@ -506,10 +574,11 @@ function handle_if_y_equations!(ifexpr, y, dict)
506574
push!(ifexpr.args, elseifexpr)
507575
(eq_entry...,)
508576
else
509-
push!(ifexpr.args, :(push!(equations, $(y.args...))))
510-
readable_code.(y.args)
577+
push!(ifexpr.args, :(push!(equations, $(y.args...))))
578+
readable_code.(y.args)
511579
end
512580
end
581+
513582
function parse_equations!(exprs, eqs, dict, body)
514583
dict[:equations] = []
515584
Base.remove_linenums!(body)
@@ -699,6 +768,7 @@ end
699768
# Handle top level branching
700769
push_something!(v, ::Nothing) = v
701770
push_something!(v, x) = push!(v, x)
771+
push_something!(v::Dict, x::Dict) = merge!(v, x)
702772
push_something!(v, x...) = push_something!.(Ref(v), x)
703773

704774
define_blocks(branch) = [Expr(branch), Expr(branch), Expr(branch), Expr(branch)]
@@ -737,21 +807,23 @@ function parse_top_level_branch(condition, x, y = nothing, branch = :if)
737807
for i in 1:lastindex(yblocks)
738808
if lastindex(blocks[i].args) == 1
739809
push_something!(blocks[i].args, Expr(:block), yblocks[i])
810+
elseif lastindex(blocks[i].args) == 0
811+
blocks[i] = yblocks[i]
740812
else
741813
push_something!(blocks[i].args, yblocks[i])
742814
end
743815
end
744816
end
745817

746818
for i in 1:lastindex(blocks)
747-
isempty(blocks[i].args) && (blocks[i] = nothing)
819+
blocks[i] !== nothing && isempty(blocks[i].args) && (blocks[i] = nothing)
748820
end
749821

750822
return blocks
751823
end
752824

753-
function parse_conditional_model_statements(comps, dict, eqs, exprs, kwargs, mod, ps, vs,
754-
component_blk, equations_blk, parameter_blk, variable_blk)
825+
function parse_conditional_model_statements(comps, dict, eqs, exprs, kwargs, mod,
826+
ps, vs, component_blk, equations_blk, parameter_blk, variable_blk)
755827
parameter_blk !== nothing &&
756828
parse_variables!(exprs.args, ps, dict, mod, :(begin
757829
$parameter_blk

0 commit comments

Comments
 (0)