Skip to content

Commit 4c518e6

Browse files
lamortonYingboMa
andauthored
Assert that ODE derivatives are w.r.t. the independent variable (#1076)
* Check that all the derivatives in a time-dependent system (other than PDE) are derivatives with respect to the independent variable. * Update src/utils.jl Nicer way to do this. Co-authored-by: Yingbo Ma <[email protected]> * Revert accidental inclusion. * Set is faster, no need to use OrderedSet. * Style improvements. * Fix warning about local assignment. * Remove checks on differentials for systems that don't support differentials. Co-authored-by: Yingbo Ma <[email protected]>
1 parent 3a0f16a commit 4c518e6

File tree

13 files changed

+77
-32
lines changed

13 files changed

+77
-32
lines changed

src/systems/control/controlsystem.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,12 @@ struct ControlSystem <: AbstractControlSystem
7272
parameters are not supplied in `ODEProblem`.
7373
"""
7474
defaults::Dict
75-
function ControlSystem(loss, deqs, iv, dvs, controls,ps, observed, name, systems, defaults)
76-
check_variables(dvs,iv)
77-
check_parameters(ps,iv)
78-
new(loss, deqs, iv, dvs, controls,ps, observed, name, systems, defaults)
75+
function ControlSystem(loss, deqs, iv, dvs, controls, ps, observed, name, systems, defaults)
76+
check_variables(dvs, iv)
77+
check_parameters(ps, iv)
78+
check_equations(deqs, iv)
79+
check_equations(observed, iv)
80+
new(loss, deqs, iv, dvs, controls, ps, observed, name, systems, defaults)
7981
end
8082
end
8183

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ struct ODESystem <: AbstractODESystem
7777
function ODESystem(deqs, iv, dvs, ps, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
7878
check_variables(dvs,iv)
7979
check_parameters(ps,iv)
80+
check_equations(deqs,iv)
8081
new(deqs, iv, dvs, ps, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
8182
end
8283
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ struct SDESystem <: AbstractODESystem
7979
function SDESystem(deqs, neqs, iv, dvs, ps, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
8080
check_variables(dvs,iv)
8181
check_parameters(ps,iv)
82+
check_equations(deqs,iv)
8283
new(deqs, neqs, iv, dvs, ps, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
8384
end
8485
end

src/systems/discrete_system/discrete_system.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ struct DiscreteSystem <: AbstractSystem
5050
"""
5151
default_p::Dict
5252
function DiscreteSystem(discreteEqs, iv, dvs, ps, observed, name, systems, default_u0, default_p)
53-
check_variables(dvs,iv)
54-
check_parameters(ps,iv)
53+
check_variables(dvs, iv)
54+
check_parameters(ps, iv)
5555
new(discreteEqs, iv, dvs, ps, observed, name, systems, default_u0, default_p)
5656
end
5757
end

src/systems/jumps/jumpsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractSystem
5252
"""
5353
connection_type::Any
5454
function JumpSystem{U}(ap::U, iv, states, ps, observed, name, systems, defaults, connection_type) where U <: ArrayPartition
55-
check_variables(states,iv)
56-
check_parameters(ps,iv)
55+
check_variables(states, iv)
56+
check_parameters(ps, iv)
5757
new{U}(ap, iv, states, ps, observed, name, systems, defaults, connection_type)
5858
end
5959
end

src/systems/reaction/reactionsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ struct ReactionSystem <: AbstractSystem
154154
iv′ = value(iv)
155155
states′ = value.(states)
156156
ps′ = value.(ps)
157-
check_variables(states′,iv′)
158-
check_parameters(ps′,iv′)
157+
check_variables(states′, iv′)
158+
check_parameters(ps′, iv′)
159159
new(eqs, iv′, states′, ps′, observed, name, systems)
160160
end
161161
end

src/utils.jl

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,19 +105,53 @@ function _readable_code(ex)
105105
end
106106
readable_code(expr) = JuliaFormatter.format_text(string(Base.remove_linenums!(_readable_code(expr))))
107107

108-
function check_parameters(ps,iv)
108+
function check_parameters(ps, iv)
109109
for p in ps
110-
isequal(iv,p) && throw(ArgumentError("Independent variable $iv not allowed in parameters."))
110+
isequal(iv, p) && throw(ArgumentError("Independent variable $iv not allowed in parameters."))
111111
end
112112
end
113113

114-
function check_variables(dvs,iv)
114+
function check_variables(dvs, iv)
115115
for dv in dvs
116-
isequal(iv,dv) && throw(ArgumentError("Independent variable $iv not allowed in dependent variables."))
116+
isequal(iv, dv) && throw(ArgumentError("Independent variable $iv not allowed in dependent variables."))
117117
isequal(iv, iv_from_nested_derivative(dv)) || throw(ArgumentError("Variable $dv is not a function of independent variable $iv."))
118118
end
119119
end
120120

121+
"Get all the independent variables with respect to which differentials are taken."
122+
function collect_differentials(eqs)
123+
vars = Set()
124+
ivs = Set()
125+
for eq in eqs
126+
vars!(vars, eq)
127+
for v in vars
128+
isdifferential(v) || continue
129+
collect_ivs_from_nested_differential!(ivs, v)
130+
end
131+
empty!(vars)
132+
end
133+
return ivs
134+
end
135+
136+
"Assert that equations are well-formed when building ODE."
137+
function check_equations(eqs, iv)
138+
ivs = collect_differentials(eqs)
139+
display = collect(ivs)
140+
length(ivs) <= 1 || throw(ArgumentError("Differential w.r.t. multiple variables $display are not allowed."))
141+
if length(ivs) == 1
142+
single_iv = pop!(ivs)
143+
isequal(single_iv, iv) || throw(ArgumentError("Differential w.r.t. variable ($single_iv) other than the independent variable ($iv) are not allowed."))
144+
end
145+
end
146+
"Get all the independent variables with respect to which differentials are taken."
147+
function collect_ivs_from_nested_differential!(ivs, x::Term)
148+
op = operation(x)
149+
if op isa Differential
150+
push!(ivs, op.x)
151+
collect_ivs_from_nested_differential!(ivs, arguments(x)[1])
152+
end
153+
end
154+
121155
iv_from_nested_derivative(x::Term) = operation(x) isa Differential ? iv_from_nested_derivative(arguments(x)[1]) : arguments(x)[1]
122156
iv_from_nested_derivative(x::Sym) = x
123157
iv_from_nested_derivative(x) = missing

test/controlsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ sol = solve(prob,BFGS())
2525
D(x) ~ - p[2]*x
2626
D(v) ~ p[1]*u^3
2727
]
28-
sys1 = ControlSystem(loss,eqs_short,t,[x,v],[u],p,name=:sys1)
29-
sys2 = ControlSystem(loss,eqs_short,t,[x,v],[u],p,name=:sys1)
30-
@test_throws ArgumentError ControlSystem(loss,[sys2.v ~ sys1.v],t, [],[],[],systems=[sys1, sys2])
28+
sys1 = ControlSystem(loss,eqs_short, t, [x, v], [u], p, name = :sys1)
29+
sys2 = ControlSystem(loss,eqs_short, t, [x, v], [u], p, name = :sys1)
30+
@test_throws ArgumentError ControlSystem(loss, [sys2.v ~ sys1.v], t, [], [], [], systems = [sys1, sys2])
3131
end

test/jumpsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ sol = solve(jprob, SSAStepper());
155155

156156
# issue #819
157157
@testset "Combined system name collisions" begin
158-
sys1 = JumpSystem([maj1,maj2], t, [S], [β,γ],name=:sys1)
159-
sys2 = JumpSystem([maj1,maj2], t, [S], [β,γ],name=:sys1)
160-
@test_throws ArgumentError JumpSystem([sys1.γ ~ sys2.γ], t,[],[], systems=[sys1, sys2])
158+
sys1 = JumpSystem([maj1, maj2], t, [S], [β, γ], name = :sys1)
159+
sys2 = JumpSystem([maj1, maj2], t, [S], [β, γ], name = :sys1)
160+
@test_throws ArgumentError JumpSystem([sys1.γ ~ sys2.γ], t, [], [], systems = [sys1, sys2])
161161
end

test/nonlinearsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,13 @@ np = NonlinearProblem(ns, [0,0,0], [1,2,3], jac=true, sparse=true)
125125
@parameters a
126126
@variables x f
127127

128-
NonlinearSystem([0 ~ -a*x + f],[x,f],[a], name=name)
128+
NonlinearSystem([0 ~ -a * x + f], [x,f], [a], name = name)
129129
end
130130

131131
function issue819()
132132
sys1 = makesys(:sys1)
133133
sys2 = makesys(:sys1)
134-
@test_throws ArgumentError NonlinearSystem([sys2.f ~ sys1.x, sys1.f ~ 0],[],[], systems=[sys1, sys2])
134+
@test_throws ArgumentError NonlinearSystem([sys2.f ~ sys1.x, sys1.f ~ 0], [], [], systems = [sys1, sys2])
135135
end
136136
issue819()
137137
end

0 commit comments

Comments
 (0)