Skip to content

Commit aac44f2

Browse files
committed
fixes
1 parent ecea402 commit aac44f2

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

src/dsl.jl

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ end
144144
function make_rs_expr(name; complete = true)
145145
rs_expr = :(ReactionSystem(Reaction[], t, [], []; name = $name))
146146
complete && (rs_expr = :(complete($rs_expr)))
147-
return Expr(:block, :(@parameters t), rs_expr)
147+
return Expr(:block, :(t = default_t()), rs_expr)
148148
end
149149

150150
# When both a name and a network expression are generated, dispatch these to the internal
@@ -270,12 +270,12 @@ function make_reaction_system(ex::Expr, name)
270270
# Extract the lines with reactions, the lines with options, and the options. Check for input errors.
271271
reaction_lines = Expr[x for x in ex.args if x.head == :tuple]
272272
option_lines = Expr[x for x in ex.args if x.head == :macrocall]
273-
options = Dict(Symbol(String(arg.args[1])[2:end]) => arg for arg in option_lines)
274273
allunique(arg.args[1] for arg in option_lines) ||
275274
error("Some options where given multiple times.")
276275
numlines = length(reaction_lines) + length(option_lines)
277276
(numlines != length(ex.args)) &&
278-
error("@reaction_network input contain $(length(ex.args) - $numlines) malformed lines.")
277+
error("@reaction_network input contain $(length(ex.args) - numlines) malformed lines.")
278+
options = Dict(Symbol(String(arg.args[1])[2:end]) => arg for arg in option_lines)
279279
any(!in(option_keys), keys(options)) &&
280280
error("The following unsupported options were used: $(filter(opt_in->!in(opt_in,option_keys), keys(options)))")
281281

@@ -316,7 +316,7 @@ function make_reaction_system(ex::Expr, name)
316316
psexpr_init = get_psexpr(ps_inferred, options)
317317
spsexpr_init = get_usexpr(sps_inferred, options; ivs)
318318
vsexpr_init = get_usexpr(vs_inferred, options, :variables; ivs)
319-
psexpr, psvar = assign_var_to_symvar_declaration(psexpr_init, "ps")
319+
psexpr, psvar = assign_var_to_symvar_declaration(psexpr_init, "ps", scalarize = false)
320320
spsexpr, spsvar = assign_var_to_symvar_declaration(spsexpr_init, "specs")
321321
vsexpr, vsvar = assign_var_to_symvar_declaration(vsexpr_init, "vars")
322322
cmpsexpr, cmpsvar = assign_var_to_symvar_declaration(cmpexpr_init, "comps")
@@ -328,8 +328,8 @@ function make_reaction_system(ex::Expr, name)
328328
# Inserts the expressions which generate the `ReactionSystem` input.
329329
$ivsexpr
330330
$psexpr
331-
$spsexpr
332331
$vsexpr
332+
$spsexpr
333333
$obsexpr
334334
$cmpsexpr
335335
$diffsexpr
@@ -419,12 +419,12 @@ function push_reactions!(reactions::Vector{DSLReaction}, subs::ExprValues,
419419
# This finds these tuples' lengths (or 1 for non-tuple forms). Inconsistent lengths yield error.
420420
lengs = (tup_leng(subs), tup_leng(prods), tup_leng(rate), tup_leng(metadata))
421421
maxlen = maximum(lengs)
422-
any(!(leng == 1 || leng == maxl) for leng in lengs) &&
422+
any(!(leng == 1 || leng == maxlen) for leng in lengs) &&
423423
error("Malformed reaction, rate: $rate, subs: $subs, prods: $prods, metadata: $metadata.")
424424

425425
# Loops through each reaction encoded by the reaction's different components.
426426
# Creates a `DSLReaction` representation and adds it to `reactions`.
427-
for i in 1:maxl
427+
for i in 1:maxlen
428428
# If the `only_use_rate` metadata was not provided, this must be inferred from the arrow.
429429
metadata_i = get_tup_arg(metadata, i)
430430
if all(arg.args[1] != :only_use_rate for arg in metadata_i.args)
@@ -488,8 +488,8 @@ function extract_sps_and_ps(reactions, excluded_syms; requiredec = false)
488488
collect(species), collect(parameters)
489489
end
490490

491-
# Function called by `extract_sps_and_ps`, recursively loops through an
492-
# expression and find symbols (adding them to the push_symbols vector).
491+
# Function called by `extract_sps_and_ps`, recursively loops through an expression and find
492+
# symbols (adding them to the push_symbols vector). Returns `nothing` to ensure type stability.
493493
function add_syms_from_expr!(push_symbols::AbstractSet, expr::ExprValues, excluded_syms)
494494
# If we have encountered a Symbol in the recursion, we can try extracting it.
495495
if expr isa Symbol
@@ -502,6 +502,7 @@ function add_syms_from_expr!(push_symbols::AbstractSet, expr::ExprValues, exclud
502502
add_syms_from_expr!(push_symbols, expr.args[i], excluded_syms)
503503
end
504504
end
505+
nothing
505506
end
506507

507508
### DSL Output Expression Builders ###
@@ -575,22 +576,28 @@ end
575576
# That calls the macro and then scalarizes all the symbols created into a vector of Nums.
576577
# stores the created symbolic variables in a variable (whose name is generated from `name`).
577578
# It will also return the name used for the variable that stores the symbolic variables.
578-
function assign_var_to_symvar_declaration(expr_init, name)
579+
# If requested, performs scalarization.
580+
function assign_var_to_symvar_declaration(expr_init, name; scalarize = true)
579581
# Generates a random variable name which (in generated code) will store the produced
580582
# symbolic variables (e.g. `var"##ps#384"`).
581583
namesym = gensym(name)
582584

583585
# If the input expression is non-empty, wrap it with additional information.
584586
if expr_init != :(())
585-
symvec = gensym()
586-
expr = quote
587-
$symvec = $expr_init
588-
$namesym = reduce(vcat, Symbolics.scalarize($symvec))
587+
if scalarize
588+
symvec = gensym()
589+
expr = quote
590+
$symvec = $expr_init
591+
$namesym = reduce(vcat, Symbolics.scalarize($symvec))
592+
end
593+
else
594+
expr = quote
595+
$namesym = $expr_init
596+
end
589597
end
590598
else
591599
expr = :($namesym = Num[])
592600
end
593-
594601
return expr, namesym
595602
end
596603

@@ -855,7 +862,7 @@ function read_ivs_option(options)
855862
if haskey(options, :ivs)
856863
ivs = Tuple(extract_syms(options, :ivs))
857864
ivsexpr = copy(options[:ivs])
858-
ivsexpr.args[1] = Symbol("@", "parameters")
865+
ivsexpr.args[1] = Symbol("@", "independent_variables")
859866
else
860867
ivs = (DEFAULT_IV_SYM,)
861868
ivsexpr = :($(DEFAULT_IV_SYM) = default_t())

0 commit comments

Comments
 (0)