Skip to content

Commit 470610e

Browse files
authored
Merge pull request #394 from SciML/myb/autosys
Automatic state detection and some performance optimization
2 parents fff4d5c + 585028d commit 470610e

File tree

5 files changed

+67
-13
lines changed

5 files changed

+67
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "3.6.2"
4+
version = "3.6.3"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Latexify, Unitful, ArrayInterface
66
using MacroTools
77
using UnPack: @unpack
88
using DiffEqJump
9+
using DataStructures: OrderedDict, OrderedSet
910

1011
using Base.Threads
1112
import MacroTools: splitdef, combinedef, postwalk, striplines

src/systems/diffeqs/first_order_transform.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using DataStructures: OrderedDict
21
function lower_varname(var::Variable, idv, order)
32
order == 0 && return var
43
name = Symbol(var.name, , string(idv.name)^order)

src/systems/diffeqs/odesystem.jl

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,22 +74,50 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
7474
ODESystem(deqs, iv′, dvs′, ps′, tgrad, jac, Wfact, Wfact_t, name, systems)
7575
end
7676

77-
var_from_nested_derivative(x) = var_from_nested_derivative(x,0)
7877
var_from_nested_derivative(x::Constant) = (missing, missing)
79-
var_from_nested_derivative(x,i) = x.op isa Differential ? var_from_nested_derivative(x.args[1],i+1) : (x.op,i)
78+
var_from_nested_derivative(x,i=0) = x.op isa Differential ? var_from_nested_derivative(x.args[1],i+1) : (x.op,i)
79+
8080
iv_from_nested_derivative(x) = x.op isa Differential ? iv_from_nested_derivative(x.args[1]) : x.args[1].op
8181
iv_from_nested_derivative(x::Constant) = missing
8282

8383
function ODESystem(eqs; kwargs...)
84-
ivs = unique(skipmissing(iv_from_nested_derivative(eq.lhs) for eq eqs))
85-
length(ivs) == 1 || throw(ArgumentError("one independent variable currently supported"))
86-
iv = first(ivs)
87-
88-
dvs = unique(skipmissing(var_from_nested_derivative(eq.lhs)[1] for eq eqs))
89-
ps = filter(vars(eq.rhs for eq eqs)) do x
90-
isparameter(x) & !isequal(x, iv)
91-
end |> collect
92-
ODESystem(eqs, iv, dvs, ps; kwargs...)
84+
# NOTE: this assumes that the order of algebric equations doesn't matter
85+
diffvars = OrderedSet{Variable}()
86+
allstates = OrderedSet{Variable}()
87+
ps = OrderedSet{Variable}()
88+
# reorder equations such that it is in the form of `diffeq, algeeq`
89+
diffeq = Equation[]
90+
algeeq = Equation[]
91+
# initial loop for finding `iv`
92+
iv = nothing
93+
for eq in eqs
94+
if !(eq.lhs isa Constant) # assume eq.lhs is either Differential or Constant
95+
iv = iv_from_nested_derivative(eq.lhs)
96+
end
97+
end
98+
iv === nothing && throw(ArgumentError("No differential variable detected."))
99+
for eq in eqs
100+
for var in vars(eq.rhs for eq eqs)
101+
var isa Variable || continue
102+
if isparameter(var)
103+
isequal(var, iv) || push!(ps, var)
104+
else
105+
push!(allstates, var)
106+
end
107+
end
108+
if eq.lhs isa Constant
109+
push!(algeeq, eq)
110+
else
111+
diffvar = first(var_from_nested_derivative(eq.lhs))
112+
iv == iv_from_nested_derivative(eq.lhs) || throw(ArgumentError("An ODESystem can only have one independent variable."))
113+
diffvar in diffvars && throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
114+
push!(diffvars, diffvar)
115+
push!(diffeq, eq)
116+
end
117+
end
118+
algevars = setdiff(allstates, diffvars)
119+
# the orders here are very important!
120+
return ODESystem(append!(diffeq, algeeq), iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...)
93121
end
94122

95123
Base.:(==)(sys1::ODESystem, sys2::ODESystem) =

test/odesystem.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ModelingToolkit, StaticArrays, LinearAlgebra
2+
using OrdinaryDiffEq
23
using DiffEqBase
34
using Test
45

@@ -192,3 +193,28 @@ prob = ODEProblem(lotka,[1.0,1.0],(0.0,1.0),[1.5,1.0,3.0,1.0])
192193

193194
de = modelingtoolkitize(prob)
194195
ODEFunction(de)(similar(prob.u0), prob.u0, prob.p, 0.1)
196+
197+
# automatic state detection for DAEs
198+
@parameters t k₁ k₂ k₃
199+
@variables y₁(t) y₂(t) y₃(t)
200+
@derivatives D'~t
201+
# reorder the system just to be a little spicier
202+
eqs = [D(y₁) ~ -k₁*y₁+k₃*y₂*y₃,
203+
0 ~ y₁ + y₂ + y₃ - 1,
204+
D(y₂) ~ k₁*y₁-k₂*y₂^2-k₃*y₂*y₃]
205+
sys = ODESystem(eqs)
206+
u0 = [y₁ => 1.0,
207+
y₂ => 0.0,
208+
y₃ => 0.0]
209+
p = [k₁ => 0.04,
210+
k₂ => 3e7,
211+
k₃ => 1e4]
212+
tspan = (0.0,100000.0)
213+
prob1 = ODEProblem(sys,u0,tspan,p)
214+
prob2 = ODEProblem(sys,u0,tspan,p,jac=true)
215+
# Wfact version is not very stable because of the lack of pivoting
216+
prob3 = ODEProblem(sys,u0,tspan,p,Wfact=true,Wfact_t=true)
217+
for (prob, atol) in [(prob1, 1e-12), (prob2, 1e-12), (prob3, 0.1)]
218+
sol = solve(prob, Rodas5())
219+
@test all(x->(sum(x), 1.0, atol=atol), sol.u)
220+
end

0 commit comments

Comments
 (0)