Skip to content

Commit 3a7bf00

Browse files
committed
Check the validity of DEs more carefully
1 parent f083800 commit 3a7bf00

File tree

4 files changed

+36
-10
lines changed

4 files changed

+36
-10
lines changed

src/structural_transformation/codegen.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,8 +524,11 @@ function ODAEProblem{iip}(
524524
parammap = DiffEqBase.NullParameters();
525525
callback = nothing,
526526
use_union = false,
527+
check = true,
527528
kwargs...
528529
) where {iip}
530+
eqs = equations(sys)
531+
check && ModelingToolkit.check_operator_variables(eqs, Differential)
529532
fun, dvs = build_torn_function(sys; kwargs...)
530533
ps = parameters(sys)
531534
defs = defaults(sys)
@@ -535,7 +538,7 @@ function ODAEProblem{iip}(
535538
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults=defs, tofloat=true)
536539
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults=defs, tofloat=!use_union, use_union)
537540

538-
has_difference = any(isdifferenceeq, equations(sys))
541+
has_difference = any(isdifferenceeq, eqs)
539542
if has_continuous_events(sys)
540543
event_cb = generate_rootfinding_callback(sys; kwargs...)
541544
else

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,6 @@ function generate_dae_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = p
9595

9696
end
9797

98-
check_derivative_variables(eq) = check_operator_variables(eq, Differential)
99-
10098
function generate_function(
10199
sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
102100
implicit_dae=false,
@@ -107,7 +105,7 @@ function generate_function(
107105

108106
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
109107
if !implicit_dae
110-
foreach(check_derivative_variables, eqs)
108+
check_operator_variables(eqs, Differential)
111109
end
112110
check_lhs(eqs, Differential, Set(dvs))
113111
# substitute x(t) by just x
@@ -130,7 +128,7 @@ end
130128

131129
function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
132130
eqs = equations(sys)
133-
foreach(check_difference_variables, eqs)
131+
check_operator_variables(eqs, Difference)
134132

135133
var2eq = Dict(arguments(eq.lhs)[1] => eq for eq in eqs if isdifference(eq.lhs))
136134

src/systems/discrete_system/discrete_system.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,14 +243,12 @@ function get_delay_val(iv, x)
243243
return -delay
244244
end
245245

246-
check_difference_variables(eq) = check_operator_variables(eq, Difference)
247-
248246
function generate_function(
249247
sys::DiscreteSystem, dvs = states(sys), ps = parameters(sys);
250248
kwargs...
251249
)
252250
eqs = equations(sys)
253-
foreach(check_difference_variables, eqs)
251+
check_operator_variables(eqs, Difference)
254252
rhss = [eq.rhs for eq in eqs]
255253

256254
u = map(x->time_varying_as_func(value(x), sys), dvs)

src/utils.jl

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,12 +233,39 @@ end
233233
end
234234

235235
"Check if difference/derivative operation occurs in the R.H.S. of an equation"
236-
function check_operator_variables(eq, op::Type, expr=eq.rhs)
236+
function _check_operator_variables(eq, op::T, expr=eq.rhs) where T
237237
istree(expr) || return nothing
238238
if operation(expr) isa op
239239
throw_invalid_operator(expr, eq, op)
240240
end
241-
foreach(expr -> check_operator_variables(eq, op, expr), SymbolicUtils.unsorted_arguments(expr))
241+
foreach(expr -> _check_operator_variables(eq, op, expr), SymbolicUtils.unsorted_arguments(expr))
242+
end
243+
"Check if all the LHS are unique"
244+
function check_operator_variables(eqs, op::T) where T
245+
ops = Set()
246+
tmp = Set()
247+
for eq in eqs
248+
_check_operator_variables(eq, op)
249+
vars!(tmp, eq.lhs)
250+
if length(tmp) == 1
251+
x = only(tmp)
252+
if op === Differential
253+
# Having a differece is fine for ODEs
254+
is_tmp_fine = isdifferential(x) || isdifference(x)
255+
else
256+
is_tmp_fine = istree(x) && !(operation(x) isa op)
257+
end
258+
else
259+
nd = count(x->istree(x) && !(operation(x) isa op), tmp)
260+
is_tmp_fine = iszero(nd)
261+
end
262+
empty!(tmp)
263+
is_tmp_fine || error("The LHS cannot contain nondifferentiated variables. Please run `structural_simplify` or use the DAE form.\nGot $eq")
264+
for v in tmp
265+
v in ops && error("The LHS operator must be unique. Please run `structural_simplify` or use the DAE form. $v appears in LHS more than once.")
266+
push!(ops, v)
267+
end
268+
end
242269
end
243270

244271
isoperator(expr, op) = istree(expr) && operation(expr) isa op

0 commit comments

Comments
 (0)