Skip to content

Commit 1205ea1

Browse files
authored
Merge pull request #1175 from SciML/change_equation_inference
Update inference of variables and default differential from `@equations` macro
2 parents 41ce941 + 64f1e39 commit 1205ea1

File tree

6 files changed

+278
-97
lines changed

6 files changed

+278
-97
lines changed

HISTORY.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,29 @@
77
(at the time the release is made). If you need a dependency version increased,
88
please open an issue and we can update it and make a new Catalyst release once
99
testing against the newer dependency version is complete.
10+
- New formula for inferring variables from equations (declared using the `@equations` options) in the DSL. The order of inference of species/variables/parameters is now:
11+
(1) Every symbol explicitly declared using `@species`, `@variables`, and `@parameters` are assigned to the correct category.
12+
(2) Every symbol used as a reaction reactant is inferred as a species.
13+
(3) Every symbol not declared in (1) or (2) that occurs in an expression provided after `@equations` is inferred as a variable.
14+
(4) Every symbol not declared in (1), (2), or (3) that occurs either as a reaction rate or stoichiometric coefficient is inferred to be a parameter.
15+
E.g. in
16+
```julia
17+
@reaction_network begin
18+
@equations V1 + S ~ V2^2
19+
(p + S + V1), S --> 0
20+
end
21+
```
22+
`S` is inferred as a species, `V1` and `V2` as variables, and `p` as a parameter. The previous special cases for the `@observables`, `@compounds`, and `@differentials` options still hold. Finally, the `@require_declaration` options (described in more detail below) can now be used to require everything to be explicitly declared.
23+
- New formula for determining whether the default differentials have been used within an `@equations` option. Now, if any expression `D(...)` is encountered (where `...` can be anything), this is inferred as usage of the default differential D. E.g. in the following equations `D` is inferred as a differential with respect to the default independent variable:
24+
```julia
25+
@reaction_network begin
26+
@equations D(V) + V ~ 1
27+
end
28+
@reaction_network begin
29+
@equations D(D(V)) ~ 1
30+
end
31+
```
32+
Please note that this cannot be used at the same time as `D` is used to represent a species, variable, or parameter (including is these are implicitly designated as such by e.g. appearing as a reaction reactant).
1033
- Array symbolics support is more consistent with ModelingToolkit v9. Parameter
1134
arrays are no longer scalarized by Catalyst, while species and variables
1235
arrays still are (as in ModelingToolkit). As such, parameter arrays should now

docs/src/inverse_problems/examples/ode_fitting_oscillation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function optimise_p(pinit, tend)
5656
newprob = remake(prob; tspan = (0.0, tend), p = p)
5757
sol = Array(solve(newprob, Rosenbrock23(); saveat = newtimes))
5858
loss = sum(abs2, sol .- sample_vals[:, 1:size(sol,2)])
59-
return loss, sol
59+
return loss
6060
end
6161
6262
# optimize for the parameters that minimize the loss

src/dsl.jl

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ struct UndeclaredSymbolicError <: Exception
290290
msg::String
291291
end
292292

293-
function Base.showerror(io::IO, err::UndeclaredSymbolicError)
293+
function Base.showerror(io::IO, err::UndeclaredSymbolicError)
294294
print(io, "UndeclaredSymbolicError: ")
295295
print(io, err.msg)
296296
end
@@ -328,11 +328,6 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
328328
parameters_declared = extract_syms(options, :parameters)
329329
variables_declared = extract_syms(options, :variables)
330330

331-
# Reads equations.
332-
vars_extracted, add_default_diff, equations = read_equations_options(
333-
options, variables_declared; requiredec)
334-
variables = vcat(variables_declared, vars_extracted)
335-
336331
# Handle independent variables
337332
if haskey(options, :ivs)
338333
ivs = Tuple(extract_syms(options, :ivs))
@@ -352,23 +347,32 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
352347
combinatoric_ratelaws = true
353348
end
354349

355-
# Reads observables.
356-
observed_vars, observed_eqs, obs_syms = read_observed_options(
357-
options, [species_declared; variables], all_ivs; requiredec)
358-
359350
# Collect species and parameters, including ones inferred from the reactions.
360351
declared_syms = Set(Iterators.flatten((parameters_declared, species_declared,
361-
variables)))
352+
variables_declared)))
362353
species_extracted, parameters_extracted = extract_species_and_parameters!(
363354
reactions, declared_syms; requiredec)
364355

356+
# Reads equations (and infers potential variables).
357+
# Excludes any parameters already extracted (if they also was a variable).
358+
declared_syms = union(declared_syms, species_extracted)
359+
vars_extracted, add_default_diff, equations = read_equations_options(
360+
options, declared_syms, parameters_extracted; requiredec)
361+
variables = vcat(variables_declared, vars_extracted)
362+
parameters_extracted = setdiff(parameters_extracted, vars_extracted)
363+
364+
# Creates the finalised parameter and species lists.
365365
species = vcat(species_declared, species_extracted)
366366
parameters = vcat(parameters_declared, parameters_extracted)
367367

368368
# Create differential expression.
369369
diffexpr = create_differential_expr(
370370
options, add_default_diff, [species; parameters; variables], tiv)
371371

372+
# Reads observables.
373+
observed_vars, observed_eqs, obs_syms = read_observed_options(
374+
options, [species_declared; variables], all_ivs; requiredec)
375+
372376
# Checks for input errors.
373377
(sum(length.([reaction_lines, option_lines])) != length(ex.args)) &&
374378
error("@reaction_network input contain $(length(ex.args) - sum(length.([reaction_lines,option_lines]))) malformed lines.")
@@ -701,7 +705,7 @@ end
701705
# `vars_extracted`: A vector with extracted variables (lhs in pure differential equations only).
702706
# `dtexpr`: If a differential equation is defined, the default derivative (D ~ Differential(t)) must be defined.
703707
# `equations`: a vector with the equations provided.
704-
function read_equations_options(options, variables_declared; requiredec = false)
708+
function read_equations_options(options, syms_declared, parameters_extracted; requiredec = false)
705709
# Prepares the equations. First, extracts equations from provided option (converting to block form if required).
706710
# Next, uses MTK's `parse_equations!` function to split input into a vector with the equations.
707711
eqs_input = haskey(options, :equations) ? options[:equations].args[3] : :(begin end)
@@ -713,34 +717,40 @@ function read_equations_options(options, variables_declared; requiredec = false)
713717
# Loops through all equations, checks for lhs of the form `D(X) ~ ...`.
714718
# When this is the case, the variable X and differential D are extracted (for automatic declaration).
715719
# Also performs simple error checks.
716-
vars_extracted = Vector{Symbol}()
720+
vars_extracted = OrderedSet{Union{Symbol, Expr}}()
717721
add_default_diff = false
718722
for eq in equations
719723
if (eq.head != :call) || (eq.args[1] != :~)
720724
error("Malformed equation: \"$eq\". Equation's left hand and right hand sides should be separated by a \"~\".")
721725
end
722726

723-
# Checks if the equation have the format D(X) ~ ... (where X is a symbol). This means that the
724-
# default differential has been used. X is added as a declared variable to the system, and
725-
# we make a note that a differential D = Differential(iv) should be made as well.
726-
lhs = eq.args[2]
727-
# if lhs: is an expression. Is a function call. The function's name is D. Calls a single symbol.
728-
if (lhs isa Expr) && (lhs.head == :call) && (lhs.args[1] == :D) &&
729-
(lhs.args[2] isa Symbol)
730-
diff_var = lhs.args[2]
731-
if in(diff_var, forbidden_symbols_error)
732-
error("A forbidden symbol ($(diff_var)) was used as an variable in this differential equation: $eq")
733-
elseif (!in(diff_var, variables_declared)) && requiredec
734-
throw(UndeclaredSymbolicError(
735-
"Unrecognized symbol $(diff_var) was used as a variable in an equation: \"$eq\". Since the @require_declaration flag is set, all variables in equations must be explicitly declared via @variables, @species, or @parameters."))
736-
else
737-
add_default_diff = true
738-
in(diff_var, variables_declared) || push!(vars_extracted, diff_var)
739-
end
727+
# If the default differential (`D`) is used, record that it should be decalred later on.
728+
if (:D union(syms_declared, parameters_extracted)) && find_D_call(eq)
729+
requiredec && throw(UndeclaredSymbolicError(
730+
"Unrecognized symbol D was used as a differential in an equation: \"$eq\". Since the @require_declaration flag is set, all differentials in equations must be explicitly declared using the @differentials option."))
731+
add_default_diff = true
732+
push!(syms_declared, :D)
740733
end
734+
735+
# Any undecalred symbolic variables encountered should be extracted as variables.
736+
add_syms_from_expr!(vars_extracted, eq, syms_declared)
737+
(!isempty(vars_extracted) && requiredec) && throw(UndeclaredSymbolicError(
738+
"Unrecognized symbolic variables $(join(vars_extracted, ", ")) detected in equation expression: \"$(string(eq))\". Since the flag @require_declaration is declared, all symbolic variables must be explicitly declared with the @species, @variables, and @parameters options."))
741739
end
742740

743-
return vars_extracted, add_default_diff, equations
741+
return collect(vars_extracted), add_default_diff, equations
742+
end
743+
744+
# Searches an expresion `expr` and returns true if it have any subexpression `D(...)` (where `...` can be anything).
745+
# Used to determine whether the default differential D has been used in any equation provided to `@equations`.
746+
function find_D_call(expr)
747+
return if Base.isexpr(expr, :call) && expr.args[1] == :D
748+
true
749+
elseif expr isa Expr
750+
any(find_D_call, expr.args)
751+
else
752+
false
753+
end
744754
end
745755

746756
# Creates an expression declaring differentials. Here, `tiv` is the time independent variables,

src/reactionsystem.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ Base.@kwdef mutable struct NetworkProperties{I <: Integer, V <: BasicSymbolic{Re
9797
stronglinkageclasses::Vector{Vector{Int}} = Vector{Vector{Int}}(undef, 0)
9898
terminallinkageclasses::Vector{Vector{Int}} = Vector{Vector{Int}}(undef, 0)
9999

100-
checkedrobust::Bool = false
100+
checkedrobust::Bool = false
101101
robustspecies::Vector{Int} = Vector{Int}(undef, 0)
102-
deficiency::Int = -1
102+
deficiency::Int = -1
103103
end
104104
#! format: on
105105

@@ -215,11 +215,11 @@ end
215215

216216
### ReactionSystem Structure ###
217217

218-
"""
218+
"""
219219
WARNING!!!
220220
221-
The following variable is used to check that code that should be updated when the `ReactionSystem`
222-
fields are updated has in fact been updated. Do not just blindly update this without first checking
221+
The following variable is used to check that code that should be updated when the `ReactionSystem`
222+
fields are updated has in fact been updated. Do not just blindly update this without first checking
223223
all such code and updating it appropriately (e.g. serialization). Please use a search for
224224
`reactionsystem_fields` throughout the package to ensure all places which should be updated, are updated.
225225
"""
@@ -318,7 +318,7 @@ struct ReactionSystem{V <: NetworkProperties} <:
318318
"""
319319
discrete_events::Vector{MT.SymbolicDiscreteCallback}
320320
"""
321-
Metadata for the system, to be used by downstream packages.
321+
Metadata for the system, to be used by downstream packages.
322322
"""
323323
metadata::Any
324324
"""
@@ -480,10 +480,10 @@ function ReactionSystem(iv; kwargs...)
480480
ReactionSystem(Reaction[], iv, [], []; kwargs...)
481481
end
482482

483-
# Called internally (whether DSL-based or programmatic model creation is used).
483+
# Called internally (whether DSL-based or programmatic model creation is used).
484484
# Creates a sorted reactions + equations vector, also ensuring reaction is first in this vector.
485-
# Extracts potential species, variables, and parameters from the input (if not provided as part of
486-
# the model creation) and creates the corresponding vectors.
485+
# Extracts potential species, variables, and parameters from the input (if not provided as part of
486+
# the model creation) and creates the corresponding vectors.
487487
# While species are ordered before variables in the unknowns vector, this ordering is not imposed here,
488488
# but carried out at a later stage.
489489
function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;
@@ -495,7 +495,7 @@ function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;
495495
any(in(obs_vars), us_in) &&
496496
error("Found an observable in the list of unknowns. This is not allowed.")
497497

498-
# Creates a combined iv vector (iv and sivs). This is used later in the function (so that
498+
# Creates a combined iv vector (iv and sivs). This is used later in the function (so that
499499
# independent variables can be excluded when encountered quantities are added to `us` and `ps`).
500500
t = value(iv)
501501
ivs = Set([t])
@@ -560,7 +560,7 @@ function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;
560560
end
561561
psv = collect(new_ps)
562562

563-
# Passes the processed input into the next `ReactionSystem` call.
563+
# Passes the processed input into the next `ReactionSystem` call.
564564
ReactionSystem(fulleqs, t, usv, psv; spatial_ivs, continuous_events,
565565
discrete_events, observed, kwargs...)
566566
end
@@ -1062,8 +1062,8 @@ end
10621062

10631063
### General `ReactionSystem`-specific Functions ###
10641064

1065-
# Checks if the `ReactionSystem` structure have been updated without also updating the
1066-
# `reactionsystem_fields` constant. If this is the case, returns `false`. This is used in
1065+
# Checks if the `ReactionSystem` structure have been updated without also updating the
1066+
# `reactionsystem_fields` constant. If this is the case, returns `false`. This is used in
10671067
# certain functionalities which would break if the `ReactionSystem` structure is updated without
10681068
# also updating these functionalities.
10691069
function reactionsystem_uptodate_check()
@@ -1241,7 +1241,7 @@ end
12411241
### `ReactionSystem` Remaking ###
12421242

12431243
"""
1244-
remake_ReactionSystem_internal(rs::ReactionSystem;
1244+
remake_ReactionSystem_internal(rs::ReactionSystem;
12451245
default_reaction_metadata::Vector{Pair{Symbol, T}} = Vector{Pair{Symbol, Any}}()) where {T}
12461246
12471247
Takes a `ReactionSystem` and remakes it, returning a modified `ReactionSystem`. Modifications depend
@@ -1274,7 +1274,7 @@ function set_default_metadata(rs::ReactionSystem; default_reaction_metadata = []
12741274
# Currently, `noise_scaling` is the only relevant metadata supported this way.
12751275
drm_dict = Dict(default_reaction_metadata)
12761276
if haskey(drm_dict, :noise_scaling)
1277-
# Finds parameters, species, and variables in the noise scaling term.
1277+
# Finds parameters, species, and variables in the noise scaling term.
12781278
ns_expr = drm_dict[:noise_scaling]
12791279
ns_syms = [Symbolics.unwrap(sym) for sym in get_variables(ns_expr)]
12801280
ns_ps = Iterators.filter(ModelingToolkit.isparameter, ns_syms)
@@ -1414,7 +1414,7 @@ function ModelingToolkit.compose(sys::ReactionSystem, systems::AbstractArray; na
14141414
MT.collect_scoped_vars!(newunknowns, newparams, ssys, iv)
14151415
end
14161416

1417-
if !isempty(newunknowns)
1417+
if !isempty(newunknowns)
14181418
@set! sys.unknowns = union(get_unknowns(sys), newunknowns)
14191419
sort!(get_unknowns(sys), by = !isspecies)
14201420
@set! sys.species = filter(isspecies, get_unknowns(sys))

0 commit comments

Comments
 (0)