Skip to content

Commit 80d02c6

Browse files
committed
Allow x ~ y in ODESystem
1 parent 6671ae6 commit 80d02c6

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,28 +119,34 @@ function ODESystem(eqs, iv=nothing; kwargs...)
119119
end
120120
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
121121
for eq in eqs
122-
for var in vars(eq.rhs for eq eqs)
123-
if isparameter(var) || isparameter(var.op)
124-
isequal(var, iv) || push!(ps, var)
125-
else
126-
push!(allstates, var)
127-
end
128-
end
129-
if !(eq.lhs isa Symbolic)
130-
push!(algeeq, eq)
131-
else
132-
diffvar = first(var_from_nested_derivative(eq.lhs))
122+
collect_vars!(allstates, ps, eq.lhs, iv)
123+
collect_vars!(allstates, ps, eq.rhs, iv)
124+
if isdiffeq(eq)
125+
diffvar, _ = var_from_nested_derivative(eq.lhs)
133126
isequal(iv, iv_from_nested_derivative(eq.lhs)) || throw(ArgumentError("An ODESystem can only have one independent variable."))
134127
diffvar in diffvars && throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
135128
push!(diffvars, diffvar)
136129
push!(diffeq, eq)
130+
else
131+
push!(algeeq, eq)
137132
end
138133
end
139134
algevars = setdiff(allstates, diffvars)
140135
# the orders here are very important!
141136
return ODESystem(append!(diffeq, algeeq), iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...)
142137
end
143138

139+
function collect_vars!(states, parameters, expr, iv)
140+
for var in vars(expr)
141+
if isparameter(var) || isparameter(var.op)
142+
isequal(var, iv) || push!(parameters, var)
143+
else
144+
push!(states, var)
145+
end
146+
end
147+
return nothing
148+
end
149+
144150
Base.:(==)(sys1::ODESystem, sys2::ODESystem) =
145151
_eq_unordered(sys1.eqs, sys2.eqs) && isequal(sys1.iv, sys2.iv) &&
146152
_eq_unordered(sys1.states, sys2.states) && _eq_unordered(sys1.ps, sys2.ps)

test/odesystem.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,14 @@ for (prob, atol) in [(prob1, 1e-12), (prob2, 1e-12), (prob3, 1e-12)]
200200
sol = solve(prob, Rodas5())
201201
@test all(x->(sum(x), 1.0, atol=atol), sol.u)
202202
end
203+
204+
@parameters t σ β
205+
@variables x(t) y(t) z(t)
206+
@derivatives D'~t
207+
eqs = [D(x) ~ σ*(y-x),
208+
D(y) ~ x-β*y,
209+
x + z ~ y]
210+
sys = ODESystem(eqs)
211+
@test all(isequal.(states(sys), [x, y, z]))
212+
@test all(isequal.(parameters(sys), [σ, β]))
213+
@test equations(sys) == eqs

0 commit comments

Comments
 (0)