Skip to content

Commit da50bf8

Browse files
committed
add update to read equations function back in
1 parent 7d6aa95 commit da50bf8

File tree

1 file changed

+25
-21
lines changed

1 file changed

+25
-21
lines changed

src/dsl.jl

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,7 @@ end
702702
# `vars_extracted`: A vector with extracted variables (lhs in pure differential equations only).
703703
# `dtexpr`: If a differential equation is defined, the default derivative (D ~ Differential(t)) must be defined.
704704
# `equations`: a vector with the equations provided.
705-
function read_equations_options(options, variables_declared; requiredec = false)
705+
function read_equations_options(options, syms_declared; requiredec = false)
706706
# Prepares the equations. First, extracts equations from provided option (converting to block form if required).
707707
# Next, uses MTK's `parse_equations!` function to split input into a vector with the equations.
708708
eqs_input = haskey(options, :equations) ? options[:equations].args[3] : :(begin end)
@@ -722,35 +722,39 @@ function read_equations_options(options, variables_declared; requiredec = false)
722722
error("Malformed equation: \"$eq\". Equation's left hand and right hand sides should be separated by a \"~\".")
723723
end
724724

725-
# Checks if the equation have the format D(X) ~ ... (where X is a symbol). This means that the
726-
# default differential has been used and we make a note that it should be decalred in the DSL output.
727-
lhs = eq.args[2]
728-
# If lhs: is an expression. Is a function call. The function's name is D. It has a single argument.
729-
if (lhs isa Expr) && (lhs.head == :call) && (lhs.args[1] == :D) &&
730-
(lhs.args[2] isa Symbol)
731-
diff_var = lhs.args[2]
732-
if in(diff_var, forbidden_symbols_error)
733-
error("A forbidden symbol ($(diff_var)) was used as an variable in this differential equation: $eq")
734-
elseif (!in(diff_var, variables_declared)) && requiredec
735-
throw(UndeclaredSymbolicError(
736-
"Unrecognized symbol $(diff_var) was used as a variable in an equation: \"$eq\". Since the @require_declaration flag is set, all variables in equations must be explicitly declared via @variables, @species, or @parameters."))
737-
else
738-
add_default_diff = true
739-
in(diff_var, variables_declared) || push!(vars_extracted, diff_var)
740-
end
741-
if !add_default_diff
742-
add_default_diff = true
743-
excluded_syms = [excluded_syms; :D]
744-
end
725+
# If the default differential (`D`) is used, record that it should be decalred later on.
726+
if !in(eq, excluded_syms) && find_D_call(eq)
727+
requiredec && throw(UndeclaredSymbolicError(
728+
"Unrecognized symbol D was used as a differential in an equation: \"$eq\". Since the @require_declaration flag is set, all differentials in equations must be explicitly declared using the @differentials option."))
729+
add_default_diff = true
730+
excluded_syms = [excluded_syms; :D]
745731
end
746732

747733
# Any undecalred symbolic variables encountered should be extracted as variables.
734+
# Additional step required to handle `requiredec = true` (to be improved later).
735+
prev_vars_extracted = deepcopy(vars_extracted)
748736
add_syms_from_expr!(vars_extracted, eq, excluded_syms)
737+
if requiredec && length(prev_vars_extracted) < length(vars_extracted)
738+
throw(UndeclaredSymbolicError(
739+
"Unrecognized symbols $(setdiff(vars_extracted, prev_vars_extracted)) was used in an equation: \"$eq\". Since the flag @require_declaration is set, all variables must be declared with the @species, @parameters, or @variables macros."))
740+
end
749741
end
750742

751743
return collect(vars_extracted), add_default_diff, equations
752744
end
753745

746+
# Searches an expresion `expr` and returns true if it have any subexpression `D(...)` (where `...` can be anything).
747+
# Used to determine whether the default differential D has been used in any equation provided to `@equations`.
748+
function find_D_call(expr)
749+
return if Base.isexpr(expr, :call) && expr.args[1] == :D
750+
true
751+
elseif expr isa Expr
752+
any(find_D_call, expr.args)
753+
else
754+
false
755+
end
756+
end
757+
754758
# Creates an expression declaring differentials. Here, `tiv` is the time independent variables,
755759
# which is used by the default differential (if it is used).
756760
function create_differential_expr(options, add_default_diff, used_syms, tiv)

0 commit comments

Comments
 (0)