Skip to content

Commit ecd28d6

Browse files
authored
Merge pull request #1122 from vyudu/dsl-no-infer
Add @no_infer flag for turning off species/variable/parameter inferring
2 parents ff149fd + c710814 commit ecd28d6

File tree

2 files changed

+114
-17
lines changed

2 files changed

+114
-17
lines changed

src/dsl.jl

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ const pure_rate_arrows = Set{Symbol}([:(=>), :(<=), :⇐, :⟽, :⇒, :⟾, :⇔
7171
# Declares the keys used for various options.
7272
const option_keys = (:species, :parameters, :variables, :ivs, :compounds, :observables,
7373
:default_noise_scaling, :differentials, :equations,
74-
:continuous_events, :discrete_events, :combinatoric_ratelaws)
74+
:continuous_events, :discrete_events, :combinatoric_ratelaws, :require_declaration)
7575

7676
### `@species` Macro ###
7777

@@ -220,13 +220,14 @@ struct ReactionStruct
220220
products::Vector{ReactantStruct}
221221
rate::ExprValues
222222
metadata::Expr
223+
rxexpr::Expr
223224

224225
function ReactionStruct(sub_line::ExprValues, prod_line::ExprValues, rate::ExprValues,
225-
metadata_line::ExprValues)
226+
metadata_line::ExprValues, rx_line::Expr)
226227
sub = recursive_find_reactants!(sub_line, 1, Vector{ReactantStruct}(undef, 0))
227228
prod = recursive_find_reactants!(prod_line, 1, Vector{ReactantStruct}(undef, 0))
228229
metadata = extract_metadata(metadata_line)
229-
new(sub, prod, rate, metadata)
230+
new(sub, prod, rate, metadata, rx_line)
230231
end
231232
end
232233

@@ -283,6 +284,17 @@ function extract_metadata(metadata_line::Expr)
283284
return metadata
284285
end
285286

287+
288+
289+
struct UndeclaredSymbolicError <: Exception
290+
msg::String
291+
end
292+
293+
function Base.showerror(io::IO, err::UndeclaredSymbolicError)
294+
print(io, "UndeclaredSymbolicError: ")
295+
print(io, err.msg)
296+
end
297+
286298
### DSL Internal Master Function ###
287299

288300
# Function for creating a ReactionSystem structure (used by the @reaction_network macro).
@@ -308,6 +320,7 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
308320
compound_expr, compound_species = read_compound_options(options)
309321
continuous_events_expr = read_events_option(options, :continuous_events)
310322
discrete_events_expr = read_events_option(options, :discrete_events)
323+
requiredec = haskey(options, :require_declaration)
311324

312325
# Parses reactions, species, and parameters.
313326
reactions = get_reactions(reaction_lines)
@@ -317,7 +330,7 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
317330

318331
# Reads equations.
319332
vars_extracted, add_default_diff, equations = read_equations_options(
320-
options, variables_declared)
333+
options, variables_declared; requiredec)
321334
variables = vcat(variables_declared, vars_extracted)
322335

323336
# Handle independent variables
@@ -341,13 +354,13 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
341354

342355
# Reads observables.
343356
observed_vars, observed_eqs, obs_syms = read_observed_options(
344-
options, [species_declared; variables], all_ivs)
357+
options, [species_declared; variables], all_ivs; requiredec)
345358

346359
# Collect species and parameters, including ones inferred from the reactions.
347360
declared_syms = Set(Iterators.flatten((parameters_declared, species_declared,
348361
variables)))
349362
species_extracted, parameters_extracted = extract_species_and_parameters!(
350-
reactions, declared_syms)
363+
reactions, declared_syms; requiredec)
351364

352365
species = vcat(species_declared, species_extracted)
353366
parameters = vcat(parameters_declared, parameters_extracted)
@@ -425,15 +438,15 @@ function get_reactions(exprs::Vector{Expr}, reactions = Vector{ReactionStruct}(u
425438
error("Error: Must provide a tuple of reaction rates when declaring a bi-directional reaction.")
426439
end
427440
push_reactions!(reactions, reaction.args[2], reaction.args[3],
428-
rate.args[1], metadata.args[1], arrow)
441+
rate.args[1], metadata.args[1], arrow, line)
429442
push_reactions!(reactions, reaction.args[3], reaction.args[2],
430-
rate.args[2], metadata.args[2], arrow)
443+
rate.args[2], metadata.args[2], arrow, line)
431444
elseif in(arrow, fwd_arrows)
432445
push_reactions!(reactions, reaction.args[2], reaction.args[3],
433-
rate, metadata, arrow)
446+
rate, metadata, arrow, line)
434447
elseif in(arrow, bwd_arrows)
435448
push_reactions!(reactions, reaction.args[3], reaction.args[2],
436-
rate, metadata, arrow)
449+
rate, metadata, arrow, line)
437450
else
438451
throw("Malformed reaction, invalid arrow type used in: $(MacroTools.striplines(line))")
439452
end
@@ -467,7 +480,7 @@ end
467480
# Takes a reaction line and creates reaction(s) from it and pushes those to the reaction array.
468481
# Used to create multiple reactions from, for instance, `k, (X,Y) --> 0`.
469482
function push_reactions!(reactions::Vector{ReactionStruct}, sub_line::ExprValues,
470-
prod_line::ExprValues, rate::ExprValues, metadata::ExprValues, arrow::Symbol)
483+
prod_line::ExprValues, rate::ExprValues, metadata::ExprValues, arrow::Symbol, line::Expr)
471484
# The rates, substrates, products, and metadata may be in a tupple form (e.g. `k, (X,Y) --> 0`).
472485
# This finds the length of these tuples (or 1 if not in tuple forms). Errors if lengs inconsistent.
473486
lengs = (tup_leng(sub_line), tup_leng(prod_line), tup_leng(rate), tup_leng(metadata))
@@ -490,7 +503,7 @@ function push_reactions!(reactions::Vector{ReactionStruct}, sub_line::ExprValues
490503

491504
push!(reactions,
492505
ReactionStruct(get_tup_arg(sub_line, i),
493-
get_tup_arg(prod_line, i), get_tup_arg(rate, i), metadata_i))
506+
get_tup_arg(prod_line, i), get_tup_arg(rate, i), metadata_i, line))
494507
end
495508
end
496509

@@ -511,20 +524,26 @@ end
511524

512525
# Function looping through all reactions, to find undeclared symbols (species or
513526
# parameters), and assign them to the right category.
514-
function extract_species_and_parameters!(reactions, excluded_syms)
527+
function extract_species_and_parameters!(reactions, excluded_syms; requiredec = false)
515528
species = OrderedSet{Union{Symbol, Expr}}()
516529
for reaction in reactions
517530
for reactant in Iterators.flatten((reaction.substrates, reaction.products))
518531
add_syms_from_expr!(species, reactant.reactant, excluded_syms)
532+
(!isempty(species) && requiredec) && throw(UndeclaredSymbolicError(
533+
"Unrecognized variables $(join(species, ", ")) detected in reaction expression: \"$(string(reaction.rxexpr))\". Since the flag @require_declaration is declared, all species must be explicitly declared with the @species macro."))
519534
end
520535
end
521536

522537
foreach(s -> push!(excluded_syms, s), species)
523538
parameters = OrderedSet{Union{Symbol, Expr}}()
524539
for reaction in reactions
525540
add_syms_from_expr!(parameters, reaction.rate, excluded_syms)
541+
(!isempty(parameters) && requiredec) && throw(UndeclaredSymbolicError(
542+
"Unrecognized parameter $(join(parameters, ", ")) detected in rate expression: $(reaction.rate) for the following reaction expression: \"$(string(reaction.rxexpr))\". Since the flag @require_declaration is declared, all parameters must be explicitly declared with the @parameters macro."))
526543
for reactant in Iterators.flatten((reaction.substrates, reaction.products))
527544
add_syms_from_expr!(parameters, reactant.stoichiometry, excluded_syms)
545+
(!isempty(parameters) && requiredec) && throw(UndeclaredSymbolicError(
546+
"Unrecognized parameters $(join(parameters, ", ")) detected in the stoichiometry for reactant $(reactant.reactant) in the following reaction expression: \"$(string(reaction.rxexpr))\". Since the flag @require_declaration is declared, all parameters must be explicitly declared with the @parameters macro."))
528547
end
529548
end
530549

@@ -682,7 +701,7 @@ end
682701
# `vars_extracted`: A vector with extracted variables (lhs in pure differential equations only).
683702
# `dtexpr`: If a differential equation is defined, the default derivative (D ~ Differential(t)) must be defined.
684703
# `equations`: a vector with the equations provided.
685-
function read_equations_options(options, variables_declared)
704+
function read_equations_options(options, variables_declared; requiredec = false)
686705
# Prepares the equations. First, extracts equations from provided option (converting to block form if required).
687706
# Next, uses MTK's `parse_equations!` function to split input into a vector with the equations.
688707
eqs_input = haskey(options, :equations) ? options[:equations].args[3] : :(begin end)
@@ -711,9 +730,13 @@ function read_equations_options(options, variables_declared)
711730
diff_var = lhs.args[2]
712731
if in(diff_var, forbidden_symbols_error)
713732
error("A forbidden symbol ($(diff_var)) was used as an variable in this differential equation: $eq")
733+
elseif (!in(diff_var, variables_declared)) && requiredec
734+
throw(UndeclaredSymbolicError(
735+
"Unrecognized symbol $(diff_var) was used as a variable in an equation: \"$eq\". Since the @require_declaration flag is set, all variables in equations must be explicitly declared via @variables, @species, or @parameters."))
736+
else
737+
add_default_diff = true
738+
in(diff_var, variables_declared) || push!(vars_extracted, diff_var)
714739
end
715-
add_default_diff = true
716-
in(diff_var, variables_declared) || push!(vars_extracted, diff_var)
717740
end
718741
end
719742

@@ -752,7 +775,7 @@ function create_differential_expr(options, add_default_diff, used_syms, tiv)
752775
end
753776

754777
# Reads the observables options. Outputs an expression ofr creating the observable variables, and a vector of observable equations.
755-
function read_observed_options(options, species_n_vars_declared, ivs_sorted)
778+
function read_observed_options(options, species_n_vars_declared, ivs_sorted; requiredec = false)
756779
if haskey(options, :observables)
757780
# Gets list of observable equations and prepares variable declaration expression.
758781
# (`options[:observables]` includes `@observables`, `.args[3]` removes this part)
@@ -763,6 +786,10 @@ function read_observed_options(options, species_n_vars_declared, ivs_sorted)
763786
for (idx, obs_eq) in enumerate(observed_eqs.args)
764787
# Extract the observable, checks errors, and continues the loop if the observable has been declared.
765788
obs_name, ivs, defaults, metadata = find_varinfo_in_declaration(obs_eq.args[2])
789+
if (requiredec && !in(obs_name, species_n_vars_declared))
790+
throw(UndeclaredSymbolicError(
791+
"An undeclared variable ($obs_name) was declared as an observable in the following observable equation: \"$obs_eq\". Since the flag @require_declaration is set, all variables must be declared with the @species, @parameters, or @variables macros."))
792+
end
766793
isempty(ivs) ||
767794
error("An observable ($obs_name) was given independent variable(s). These should not be given, as they are inferred automatically.")
768795
isnothing(defaults) ||

test/dsl/dsl_options.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,3 +1022,73 @@ let
10221022
@parameters v n
10231023
@test isequal(Catalyst.expand_registered_functions(equations(rn4)[1]), D(A) ~ v*(A^n))
10241024
end
1025+
1026+
### test that @no_infer properly throws errors when undeclared variables are written
1027+
1028+
import Catalyst: UndeclaredSymbolicError
1029+
let
1030+
# Test error when species are inferred
1031+
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
1032+
@require_declaration
1033+
@parameters k
1034+
k, A --> B
1035+
end
1036+
@test_nowarn @macroexpand @reaction_network begin
1037+
@require_declaration
1038+
@species A(t) B(t)
1039+
@parameters k
1040+
k, A --> B
1041+
end
1042+
1043+
# Test error when a parameter in rate is inferred
1044+
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
1045+
@require_declaration
1046+
@species A(t) B(t)
1047+
@parameters k
1048+
k*n, A --> B
1049+
end
1050+
@test_nowarn @macroexpand @reaction_network begin
1051+
@require_declaration
1052+
@parameters n k
1053+
@species A(t) B(t)
1054+
k*n, A --> B
1055+
end
1056+
1057+
# Test error when a parameter in stoichiometry is inferred
1058+
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
1059+
@require_declaration
1060+
@parameters k
1061+
@species A(t) B(t)
1062+
k, n*A --> B
1063+
end
1064+
@test_nowarn @macroexpand @reaction_network begin
1065+
@require_declaration
1066+
@parameters k n
1067+
@species A(t) B(t)
1068+
k, n*A --> B
1069+
end
1070+
1071+
# Test error when a variable in an equation is inferred
1072+
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
1073+
@require_declaration
1074+
@equations D(V) ~ V^2
1075+
end
1076+
@test_nowarn @macroexpand @reaction_network begin
1077+
@require_declaration
1078+
@variables V(t)
1079+
@equations D(V) ~ V^2
1080+
end
1081+
1082+
# Test error when a variable in an observable is inferred
1083+
@test_throws UndeclaredSymbolicError @macroexpand @reaction_network begin
1084+
@require_declaration
1085+
@variables X1(t)
1086+
@observables X2 ~ X1
1087+
end
1088+
@test_nowarn @macroexpand @reaction_network begin
1089+
@require_declaration
1090+
@variables X1(t) X2(t)
1091+
@observables X2 ~ X1
1092+
end
1093+
end
1094+

0 commit comments

Comments
 (0)