Skip to content

Commit bd58adb

Browse files
Merge pull request #92 from JuliaDiffEq/hg/refactor/systems
Refactor system storage
2 parents 668add3 + 259a46f commit bd58adb

File tree

6 files changed

+91
-115
lines changed

6 files changed

+91
-115
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Each operation builds an `Operation` type, and thus `eqs` is an array of
4646
analyzed by other programs. We can turn this into a `DiffEqSystem` via:
4747

4848
```julia
49-
de = DiffEqSystem(eqs,[t],[x,y,z],[σ,ρ,β])
49+
de = DiffEqSystem(eqs,t,[x,y,z],[σ,ρ,β])
5050
de = DiffEqSystem(eqs)
5151
```
5252

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Base.convert(::Type{Variable},x::Int64) = Constant(x)
2020

2121
function caclulate_jacobian end
2222

23-
@enum FunctionVersions ArrayFunction=1 SArrayFunction=2
23+
@enum FunctionVersion ArrayFunction=1 SArrayFunction=2
2424

2525
include("operations.jl")
2626
include("differentials.jl")

src/equations.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
export Equation
22

33

4-
mutable struct Equation
4+
struct Equation
55
lhs::Expression
66
rhs::Expression
77
end
88
Base.broadcastable(eq::Equation) = Ref(eq)
9+
Base.:(==)(a::Equation, b::Equation) = (a.lhs, a.rhs) == (b.lhs, b.rhs)
910

1011
Base.:~(lhs::Expression, rhs::Expression) = Equation(lhs, rhs)
1112
Base.:~(lhs::Expression, rhs::Number ) = Equation(lhs, rhs)
1213
Base.:~(lhs::Number , rhs::Expression) = Equation(lhs, rhs)
1314

1415

1516
_is_dependent(x::Variable) = x.subtype === :Unknown && !isempty(x.dependents)
16-
_is_parameter(ivs) = x -> x.subtype === :Parameter && x ivs
17+
_is_parameter(iv) = x -> x.subtype === :Parameter && x iv
1718
_subtype(subtype::Symbol) = x -> x.subtype === subtype
1819

1920
function extract_elements(eqs, predicates)
@@ -29,10 +30,10 @@ function extract_elements(eqs, predicates)
2930
return result
3031
end
3132

33+
get_args(O::Operation) = O.args
34+
get_args(eq::Equation) = Expression[eq.lhs, eq.rhs]
3235
function vars!(vars, op)
33-
args = isa(op, Equation) ? Expression[op.lhs, op.rhs] : op.args
34-
35-
for arg args
36+
for arg get_args(op)
3637
if isa(arg, Operation)
3738
vars!(vars, arg)
3839
elseif isa(arg, Variable)

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,58 @@
1-
mutable struct DiffEqSystem <: AbstractSystem
2-
eqs::Vector{Equation}
3-
ivs::Vector{Variable}
1+
using Base: RefValue
2+
3+
4+
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
5+
6+
struct DiffEq # D(x) = t
7+
D::Differential # D
8+
var::Variable # x
9+
rhs::Expression # t
10+
end
11+
function Base.convert(::Type{DiffEq}, eq::Equation)
12+
isintermediate(eq) && throw(ArgumentError("intermediate equation received"))
13+
return DiffEq(eq.lhs.op, eq.lhs.args[1], eq.rhs)
14+
end
15+
Base.:(==)(a::DiffEq, b::DiffEq) = (a.D, a.var, a.rhs) == (b.D, b.var, b.rhs)
16+
get_args(eq::DiffEq) = Expression[eq.var, eq.rhs]
17+
18+
struct DiffEqSystem <: AbstractSystem
19+
eqs::Vector{DiffEq}
20+
iv::Variable
421
dvs::Vector{Variable}
522
ps::Vector{Variable}
6-
jac::Matrix{Expression}
7-
function DiffEqSystem(eqs, ivs, dvs, ps, jac)
8-
all(!isintermediate, eqs) ||
9-
throw(ArgumentError("no intermediate equations permitted in DiffEqSystem"))
10-
11-
new(eqs, ivs, dvs, ps, jac)
23+
jac::RefValue{Matrix{Expression}}
24+
function DiffEqSystem(eqs, iv, dvs, ps)
25+
jac = RefValue(Matrix{Expression}(undef, 0, 0))
26+
new(eqs, iv, dvs, ps, jac)
1227
end
1328
end
1429

15-
DiffEqSystem(eqs, ivs, dvs, ps) = DiffEqSystem(eqs, ivs, dvs, ps, Matrix{Expression}(undef,0,0))
16-
1730
function DiffEqSystem(eqs)
1831
dvs, = extract_elements(eqs, [_is_dependent])
1932
ivs = unique(vcat((dv.dependents for dv dvs)...))
20-
ps, = extract_elements(eqs, [_is_parameter(ivs)])
21-
DiffEqSystem(eqs, ivs, dvs, ps, Matrix{Expression}(undef,0,0))
33+
length(ivs) == 1 || throw(ArgumentError("one independent variable currently supported"))
34+
iv = first(ivs)
35+
ps, = extract_elements(eqs, [_is_parameter(iv)])
36+
DiffEqSystem(eqs, iv, dvs, ps)
2237
end
2338

24-
function DiffEqSystem(eqs, ivs)
25-
dvs, ps = extract_elements(eqs, [_is_dependent, _is_parameter(ivs)])
26-
DiffEqSystem(eqs, ivs, dvs, ps, Matrix{Expression}(undef,0,0))
39+
function DiffEqSystem(eqs, iv)
40+
dvs, ps = extract_elements(eqs, [_is_dependent, _is_parameter(iv)])
41+
DiffEqSystem(eqs, iv, dvs, ps)
2742
end
2843

29-
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
30-
3144

32-
function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
45+
function generate_ode_function(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)
3346
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
3447
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
3548
sys_exprs = build_equals_expr.(sys.eqs)
36-
if version == ArrayFunction
37-
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in eachindex(sys.dvs)]
49+
if version === ArrayFunction
50+
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.iv.name)"))) for i in eachindex(sys.dvs)]
3851
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
3952
block = expr_arr_to_block(exprs)
4053
:((du,u,p,t)->$(toexpr(block)))
41-
elseif version == SArrayFunction
42-
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in eachindex(sys.dvs)]
54+
elseif version === SArrayFunction
55+
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.iv.name)"))) for i in eachindex(sys.dvs)]
4356
svector_expr = quote
4457
E = eltype(tuple($(dvar_exprs...)))
4558
T = StaticArrays.similar_type(typeof(u), E)
@@ -51,26 +64,24 @@ function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
5164
end
5265
end
5366

54-
function build_equals_expr(eq::Equation)
55-
@assert !isintermediate(eq)
56-
57-
lhs = Symbol(eq.lhs.args[1].name, :_, eq.lhs.op.x.name)
67+
function build_equals_expr(eq::DiffEq)
68+
lhs = Symbol(eq.var.name, :_, eq.D.x.name)
5869
return :($lhs = $(convert(Expr, eq.rhs)))
5970
end
6071

6172
function calculate_jacobian(sys::DiffEqSystem, simplify=true)
73+
isempty(sys.jac[]) || return sys.jac[] # use cached Jacobian, if possible
6274
rhs = [eq.rhs for eq in sys.eqs]
6375

64-
sys_exprs = calculate_jacobian(rhs, sys.dvs)
65-
sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs]
66-
sys_exprs
76+
jac = expand_derivatives.(calculate_jacobian(rhs, sys.dvs))
77+
sys.jac[] = jac # cache Jacobian
78+
return jac
6779
end
6880

6981
function generate_ode_jacobian(sys::DiffEqSystem, simplify=true)
7082
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
7183
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
7284
jac = calculate_jacobian(sys, simplify)
73-
sys.jac = jac
7485
jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
7586
exprs = vcat(var_exprs,param_exprs,vec(jac_exprs))
7687
block = expr_arr_to_block(exprs)
@@ -80,7 +91,7 @@ end
8091
function generate_ode_iW(sys::DiffEqSystem, simplify=true)
8192
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
8293
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
83-
jac = sys.jac
94+
jac = calculate_jacobian(sys, simplify)
8495

8596
gam = Parameter(:gam)
8697

@@ -109,12 +120,12 @@ function generate_ode_iW(sys::DiffEqSystem, simplify=true)
109120
:((iW,u,p,gam,t)->$(block)),:((iW,u,p,gam,t)->$(block2))
110121
end
111122

112-
function DiffEqBase.ODEFunction(sys::DiffEqSystem;version = ArrayFunction,kwargs...)
113-
expr = generate_ode_function(sys;version=version,kwargs...)
114-
if version == ArrayFunction
115-
ODEFunction{true}(eval(expr))
116-
elseif version == SArrayFunction
117-
ODEFunction{false}(eval(expr))
123+
function DiffEqBase.ODEFunction(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)
124+
expr = generate_ode_function(sys; version = version)
125+
if version === ArrayFunction
126+
ODEFunction{true}(eval(expr))
127+
elseif version === SArrayFunction
128+
ODEFunction{false}(eval(expr))
118129
end
119130
end
120131

Lines changed: 25 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,56 @@
1-
extract_idv(eq::Equation) = eq.lhs.op.x
1+
extract_idv(eq::DiffEq) = eq.D.x
22

3-
function lower_varname(O::Operation, naming_scheme; lower=false)
4-
@assert isa(O.op, Differential)
5-
6-
D, x = O.op, O.args[1]
3+
function lower_varname(D::Differential, x; lower=false)
74
order = lower ? D.order-1 : D.order
8-
9-
lower_varname(x, D.x, order, naming_scheme)
5+
return lower_varname(x, D.x, order)
106
end
11-
function lower_varname(var::Variable, idv, order::Int, naming_scheme)
7+
function lower_varname(var::Variable, idv, order::Int)
128
sym = var.name
13-
name = order == 0 ? sym : Symbol(sym, naming_scheme, string(idv.name)^order)
9+
name = order == 0 ? sym : Symbol(sym, :_, string(idv.name)^order)
1410
return Variable(name, var.subtype, var.dependents)
1511
end
1612

17-
function ode_order_lowering(sys::DiffEqSystem; kwargs...)
18-
eqs = sys.eqs
19-
ivs = sys.ivs
20-
eqs_lowered = ode_order_lowering(eqs; kwargs...)
21-
DiffEqSystem(eqs_lowered, ivs)
13+
function ode_order_lowering(sys::DiffEqSystem)
14+
eqs_lowered = ode_order_lowering(sys.eqs, sys.iv)
15+
DiffEqSystem(eqs_lowered, sys.iv)
2216
end
23-
ode_order_lowering(eqs; naming_scheme = "_") = ode_order_lowering!(deepcopy(eqs), naming_scheme)
24-
function ode_order_lowering!(eqs, naming_scheme)
25-
idv = extract_idv(eqs[1])
26-
D = Differential(idv, 1)
17+
function ode_order_lowering(eqs, iv)
18+
D = Differential(iv, 1)
2719
var_order = Dict{Variable,Int}()
2820
vars = Variable[]
29-
dv_name = eqs[1].lhs.args[1].subtype
21+
new_eqs = similar(eqs, DiffEq)
3022

31-
for eq in eqs
23+
for (i, eq) enumerate(eqs)
3224
var, maxorder = extract_var_order(eq)
3325
maxorder == 1 && continue # fast pass
3426
if maxorder > get(var_order, var, 0)
3527
var_order[var] = maxorder
3628
var vars || push!(vars, var)
3729
end
38-
lhs_renaming!(eq, D, naming_scheme)
39-
rhs_renaming!(eq, naming_scheme)
30+
var′ = lower_varname(eq.D, eq.var, lower = true)
31+
rhs′ = rename(eq.rhs)
32+
new_eqs[i] = DiffEq(D, var′, rhs′)
4033
end
4134

4235
for var vars
4336
order = var_order[var]
4437
for o in (order-1):-1:1
45-
lhs = D(lower_varname(var, idv, o-1, naming_scheme))
46-
rhs = lower_varname(var, idv, o, naming_scheme)
47-
eq = Equation(lhs, rhs)
48-
push!(eqs, eq)
38+
lvar = lower_varname(var, iv, o-1)
39+
rhs = lower_varname(var, iv, o)
40+
eq = DiffEq(D, lvar, rhs)
41+
push!(new_eqs, eq)
4942
end
5043
end
5144

52-
return eqs
45+
return new_eqs
5346
end
5447

55-
function lhs_renaming!(eq, D, naming_scheme)
56-
eq.lhs = D(lower_varname(eq.lhs, naming_scheme, lower=true))
57-
return eq
48+
function rename(O::Expression)
49+
isa(O, Operation) || return O
50+
isa(O.op, Differential) && return lower_varname(O.op, O.args[1])
51+
return Operation(O.op, rename.(O.args))
5852
end
59-
rhs_renaming!(eq, naming_scheme) = _rec_renaming!(eq.rhs, naming_scheme)
6053

61-
function _rec_renaming!(rhs, naming_scheme)
62-
isa(rhs, Operation) && isa(rhs.op, Differential) && return lower_varname(rhs, naming_scheme)
63-
if rhs isa Operation
64-
args = rhs.args
65-
for i in eachindex(args)
66-
args[i] = _rec_renaming!(args[i], naming_scheme)
67-
end
68-
end
69-
rhs
70-
end
71-
72-
function extract_var_order(eq)
73-
# We assume that the differential with the highest order is always going to be in the LHS
74-
dv = eq.lhs
75-
var = dv.args[1]
76-
order = dv.op.order
77-
return (var, order)
78-
end
54+
extract_var_order(eq::DiffEq) = (eq.var, eq.D.order)
7955

8056
export ode_order_lowering

test/system_construction.jl

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using Test
1111
eqs = [D(x) ~ σ*(y-x),
1212
D(y) ~ x*-z)-y,
1313
D(z) ~ x*y - β*z]
14-
de = DiffEqSystem(eqs,[t],[x,y,z],[σ,ρ,β])
14+
de = DiffEqSystem(eqs,t,[x,y,z],[σ,ρ,β])
1515
ModelingToolkit.generate_ode_function(de)
1616
ModelingToolkit.generate_ode_function(de;version=ModelingToolkit.SArrayFunction)
1717
jac_expr = ModelingToolkit.generate_ode_jacobian(de)
@@ -20,10 +20,11 @@ f = ODEFunction(de)
2020
ModelingToolkit.generate_ode_iW(de)
2121

2222
# Differential equation with automatic extraction of variables
23-
de2 = DiffEqSystem(eqs, [t])
23+
de2 = DiffEqSystem(eqs, t)
2424

2525
function test_vars_extraction(de, de2)
26-
for el in (:ivs, :dvs, :ps)
26+
@test de.iv == de2.iv
27+
for el in (:dvs, :ps)
2728
names2 = sort(collect(var.name for var in getfield(de2,el)))
2829
names = sort(collect(var.name for var in getfield(de,el)))
2930
@test names2 == names
@@ -67,34 +68,21 @@ end
6768
@Unknown u(t) u_tt(t) u_t(t) x_t(t)
6869
eqs = [D3(u) ~ 2(D2(u)) + D(u) + D(x) + 1
6970
D2(x) ~ D(x) + 2]
70-
de = DiffEqSystem(eqs, [t])
71+
de = DiffEqSystem(eqs, t)
7172
de1 = ode_order_lowering(de)
7273
lowered_eqs = [D(u_tt) ~ 2u_tt + u_t + x_t + 1
7374
D(x_t) ~ x_t + 2
7475
D(u_t) ~ u_tt
7576
D(u) ~ u_t
7677
D(x) ~ x_t]
77-
function test_eqs(eqs1, eqs2)
78-
length(eqs1) == length(eqs2) || return false
79-
eq = true
80-
for (eq1, eq2) in zip(eqs1, eqs2)
81-
lhs1, lhs2 = eq1.lhs, eq2.lhs
82-
typeof(lhs1) === typeof(lhs2) || return false
83-
for f in fieldnames(typeof(lhs1))
84-
eq = eq & isequal(getfield(lhs1, f), getfield(lhs2, f))
85-
end
86-
eq = eq & isequal(eq1.rhs, eq2.rhs)
87-
end
88-
eq
89-
end
90-
@test test_eqs(de1.eqs, lowered_eqs)
78+
@test de1.eqs == convert.(ModelingToolkit.DiffEq, lowered_eqs)
9179

9280
# Internal calculations
9381
a = y - x
9482
eqs = [D(x) ~ σ*a,
9583
D(y) ~ x*-z)-y,
9684
D(z) ~ x*y - β*z]
97-
de = DiffEqSystem(eqs,[t],[x,y,z],[σ,ρ,β])
85+
de = DiffEqSystem(eqs,t,[x,y,z],[σ,ρ,β])
9886
ModelingToolkit.generate_ode_function(de)
9987
jac = ModelingToolkit.calculate_jacobian(de)
10088
f = ODEFunction(de)
@@ -118,8 +106,8 @@ ModelingToolkit.generate_nlsys_function(ns)
118106
_x = y / C
119107
eqs = [D(x) ~ -A*x,
120108
D(y) ~ A*x - B*_x]
121-
de = DiffEqSystem(eqs,[t],[x,y],[A,B,C])
122-
test_vars_extraction(de, DiffEqSystem(eqs,[t]))
109+
de = DiffEqSystem(eqs,t,[x,y],[A,B,C])
110+
test_vars_extraction(de, DiffEqSystem(eqs,t))
123111
test_vars_extraction(de, DiffEqSystem(eqs))
124112
@test eval(ModelingToolkit.generate_ode_function(de))([0.0,0.0],[1.0,2.0],[1,2,3],0.0) -1/3
125113

0 commit comments

Comments
 (0)