Skip to content

Commit 14393d3

Browse files
Merge pull request #94 from JuliaDiffEq/hg/refactor/systems
Refactor system functions
2 parents bd58adb + 6d6b5bf commit 14393d3

File tree

7 files changed

+130
-148
lines changed

7 files changed

+130
-148
lines changed

README.md

Lines changed: 24 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -56,24 +56,16 @@ This can then generate the function. For example, we can see the
5656
generated code via:
5757

5858
```julia
59-
ModelingToolkit.generate_ode_function(de)
59+
generate_function(de)
6060

6161
## Which returns:
62-
:((du, u, p, t)->begin
63-
x = u[1]
64-
y = u[2]
65-
z = u[3]
66-
σ = p[1]
67-
ρ = p[2]
68-
β = p[3]
69-
x_t = σ * (y - x)
70-
y_t = x *- z) - y
71-
z_t = x * y - β * z
72-
du[1] = x_t
73-
du[2] = y_t
74-
du[3] = z_t
75-
end
76-
end)
62+
:((##363, u, p, t)->begin
63+
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
64+
##363[1] = σ * (y - x)
65+
##363[2] = x * (ρ - z) - y
66+
##363[3] = x * y - β * z
67+
end
68+
end)
7769
```
7870
7971
and get the generated function via:
@@ -97,25 +89,19 @@ eqs = [0 ~ σ*(y-x),
9789
0 ~ x*-z)-y,
9890
0 ~ x*y - β*z]
9991
ns = NonlinearSystem(eqs)
100-
nlsys_func = ModelingToolkit.generate_nlsys_function(ns)
92+
nlsys_func = generate_function(ns)
10193
```
10294
10395
which generates:
10496
10597
```julia
106-
(du, u, p)->begin # C:\Users\Chris\.julia\v0.6\ModelingToolkit\src\systems.jl, line 51:
107-
begin # C:\Users\Chris\.julia\v0.6\ModelingToolkit\src\utils.jl, line 2:
108-
y = u[1]
109-
x = u[2]
110-
z = u[3]
111-
σ = p[1]
112-
ρ = p[2]
113-
β = p[3]
114-
resid[1] = σ * (y - x)
115-
resid[2] = x *- z) - y
116-
resid[3] = x * y - β * z
117-
end
118-
end
98+
:((##364, u, p)->begin
99+
let (x, z, y, ρ, σ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
100+
##364[1] = σ * (y - x)
101+
##364[2] = x * (ρ - z) - y
102+
##364[3] = x * y - β * z
103+
end
104+
end)
119105
```
120106
121107
We can use this to build a nonlinear function for use with NLsolve.jl:
@@ -287,26 +273,19 @@ eqs = [0 ~ σ*a,
287273
0 ~ x*-z)-y,
288274
0 ~ x*y - β*z]
289275
ns = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
290-
nlsys_func = ModelingToolkit.generate_nlsys_function(ns)
276+
nlsys_func = generate_function(ns)
291277
```
292278
293279
expands to:
294280
295281
```julia
296-
:((du, u, p)->begin # C:\Users\Chris\.julia\v0.6\ModelingToolkit\src\systems.jl, line 85:
297-
begin # C:\Users\Chris\.julia\v0.6\ModelingToolkit\src\utils.jl, line 2:
298-
x = u[1]
299-
y = u[2]
300-
z = u[3]
301-
σ = p[1]
302-
ρ = p[2]
303-
β = p[3]
304-
a = y - x
305-
resid[1] = σ * a
306-
resid[2] = x *- z) - y
307-
resid[3] = x * y - β * z
308-
end
309-
end)
282+
:((##365, u, p)->begin
283+
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
284+
##365[1] = σ * (y - x)
285+
##365[2] = x * (ρ - z) - y
286+
##365[3] = x * y - β * z
287+
end
288+
end)
310289
```
311290
312291
In addition, the Jacobian calculations take into account intermediate variables

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import MacroTools: splitdef, combinedef
99
abstract type Expression <: Number end
1010
abstract type AbstractOperation <: Expression end
1111
abstract type AbstractComponent <: Expression end
12-
abstract type AbstractSystem end
1312

1413
include("variables.jl")
1514

@@ -25,6 +24,7 @@ function caclulate_jacobian end
2524
include("operations.jl")
2625
include("differentials.jl")
2726
include("equations.jl")
27+
include("systems/systems.jl")
2828
include("systems/diffeqs/diffeqsystem.jl")
2929
include("systems/diffeqs/first_order_transform.jl")
3030
include("systems/nonlinear/nonlinear_system.jl")

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 20 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
export DiffEqSystem, ODEFunction
2+
3+
14
using Base: RefValue
25

36

@@ -42,34 +45,7 @@ function DiffEqSystem(eqs, iv)
4245
end
4346

4447

45-
function generate_ode_function(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)
46-
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
47-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
48-
sys_exprs = build_equals_expr.(sys.eqs)
49-
if version === ArrayFunction
50-
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.iv.name)"))) for i in eachindex(sys.dvs)]
51-
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
52-
block = expr_arr_to_block(exprs)
53-
:((du,u,p,t)->$(toexpr(block)))
54-
elseif version === SArrayFunction
55-
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.iv.name)"))) for i in eachindex(sys.dvs)]
56-
svector_expr = quote
57-
E = eltype(tuple($(dvar_exprs...)))
58-
T = StaticArrays.similar_type(typeof(u), E)
59-
T($(dvar_exprs...))
60-
end
61-
exprs = vcat(var_exprs,param_exprs,sys_exprs,svector_expr)
62-
block = expr_arr_to_block(exprs)
63-
:((u,p,t)->$(toexpr(block)))
64-
end
65-
end
66-
67-
function build_equals_expr(eq::DiffEq)
68-
lhs = Symbol(eq.var.name, :_, eq.D.x.name)
69-
return :($lhs = $(convert(Expr, eq.rhs)))
70-
end
71-
72-
function calculate_jacobian(sys::DiffEqSystem, simplify=true)
48+
function calculate_jacobian(sys::DiffEqSystem)
7349
isempty(sys.jac[]) || return sys.jac[] # use cached Jacobian, if possible
7450
rhs = [eq.rhs for eq in sys.eqs]
7551

@@ -78,20 +54,19 @@ function calculate_jacobian(sys::DiffEqSystem, simplify=true)
7854
return jac
7955
end
8056

81-
function generate_ode_jacobian(sys::DiffEqSystem, simplify=true)
82-
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
83-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
84-
jac = calculate_jacobian(sys, simplify)
85-
jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
86-
exprs = vcat(var_exprs,param_exprs,vec(jac_exprs))
87-
block = expr_arr_to_block(exprs)
88-
:((J,u,p,t)->$(block))
57+
function generate_jacobian(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)
58+
jac = calculate_jacobian(sys)
59+
return build_function(jac, sys.dvs, sys.ps, (sys.iv.name,); version = version)
60+
end
61+
62+
function generate_function(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)
63+
rhss = [eq.rhs for eq sys.eqs]
64+
return build_function(rhss, sys.dvs, sys.ps, (sys.iv.name,); version = version)
8965
end
9066

91-
function generate_ode_iW(sys::DiffEqSystem, simplify=true)
92-
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
93-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
94-
jac = calculate_jacobian(sys, simplify)
67+
68+
function generate_ode_iW(sys::DiffEqSystem, simplify=true; version::FunctionVersion = ArrayFunction)
69+
jac = calculate_jacobian(sys)
9570

9671
gam = Parameter(:gam)
9772

@@ -110,25 +85,18 @@ function generate_ode_iW(sys::DiffEqSystem, simplify=true)
11085
iW_t = simplify_constants.(iW_t)
11186
end
11287

113-
iW_exprs = [:(iW[$i,$j] = $(convert(Expr, iW[i,j]))) for i in 1:size(iW,1), j in 1:size(iW,2)]
114-
exprs = vcat(var_exprs,param_exprs,vec(iW_exprs))
115-
block = expr_arr_to_block(exprs)
88+
vs, ps = sys.dvs, sys.ps
89+
iW_func = build_function(iW , vs, ps, (:gam,:t); version = version)
90+
iW_t_func = build_function(iW_t, vs, ps, (:gam,:t); version = version)
11691

117-
iW_t_exprs = [:(iW[$i,$j] = $(convert(Expr, iW_t[i,j]))) for i in 1:size(iW_t,1), j in 1:size(iW_t,2)]
118-
exprs = vcat(var_exprs,param_exprs,vec(iW_t_exprs))
119-
block2 = expr_arr_to_block(exprs)
120-
:((iW,u,p,gam,t)->$(block)),:((iW,u,p,gam,t)->$(block2))
92+
return (iW_func, iW_t_func)
12193
end
12294

12395
function DiffEqBase.ODEFunction(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)
124-
expr = generate_ode_function(sys; version = version)
96+
expr = generate_function(sys; version = version)
12597
if version === ArrayFunction
12698
ODEFunction{true}(eval(expr))
12799
elseif version === SArrayFunction
128100
ODEFunction{false}(eval(expr))
129101
end
130102
end
131-
132-
133-
export DiffEqSystem, ODEFunction
134-
export generate_ode_function
Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
1+
export NonlinearSystem
2+
3+
4+
struct NLEq
5+
rhs::Expression
6+
end
7+
function Base.convert(::Type{NLEq}, eq::Equation)
8+
isequal(eq.lhs, Constant(0)) || return NLEq(eq.rhs - eq.lhs)
9+
return NLEq(eq.rhs)
10+
end
11+
Base.:(==)(a::NLEq, b::NLEq) = a.rhs == b.rhs
12+
get_args(eq::NLEq) = Expression[eq.rhs]
13+
114
struct NonlinearSystem <: AbstractSystem
2-
eqs::Vector{Equation}
15+
eqs::Vector{NLEq}
316
vs::Vector{Variable}
417
ps::Vector{Variable}
518
end
@@ -9,40 +22,19 @@ function NonlinearSystem(eqs)
922
NonlinearSystem(eqs, vs, ps)
1023
end
1124

12-
function generate_nlsys_function(sys::NonlinearSystem)
13-
var_exprs = [:($(sys.vs[i].name) = u[$i]) for i in 1:length(sys.vs)]
14-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
15-
sys_eqs, calc_eqs = partition(eq -> isequal(eq.lhs, Constant(0)), sys.eqs)
16-
calc_exprs = [:($(eq.lhs.name) = $(eq.rhs)) for eq in calc_eqs if isa(eq.lhs, Variable)]
17-
sys_exprs = [:($(Symbol("resid[$i]")) = $(sys_eqs[i].rhs)) for i in eachindex(sys_eqs)]
1825

19-
exprs = vcat(var_exprs,param_exprs,calc_exprs,sys_exprs)
20-
block = expr_arr_to_block(exprs)
21-
:((du,u,p)->$(block))
26+
function calculate_jacobian(sys::NonlinearSystem)
27+
rhs = [eq.rhs for eq in sys.eqs]
28+
jac = expand_derivatives.(calculate_jacobian(rhs, sys.vs))
29+
return jac
2230
end
2331

24-
function calculate_jacobian(sys::NonlinearSystem,simplify=true)
25-
sys_eqs, calc_eqs = partition(eq -> isequal(eq.lhs, Constant(0)), sys.eqs)
26-
rhs = [eq.rhs for eq in sys_eqs]
27-
28-
for calc_eq calc_eqs
29-
find_replace!.(rhs, calc_eq.lhs, calc_eq.rhs)
30-
end
31-
32-
sys_exprs = calculate_jacobian(rhs,sys.vs)
33-
sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs]
34-
sys_exprs
32+
function generate_jacobian(sys::NonlinearSystem; version::FunctionVersion = ArrayFunction)
33+
jac = calculate_jacobian(sys)
34+
return build_function(jac, sys.vs, sys.ps; version = version)
3535
end
3636

37-
function generate_nlsys_jacobian(sys::NonlinearSystem,simplify=true)
38-
var_exprs = [:($(sys.vs[i].name) = u[$i]) for i in 1:length(sys.vs)]
39-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
40-
jac = calculate_jacobian(sys,simplify)
41-
jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
42-
exprs = vcat(var_exprs,param_exprs,vec(jac_exprs))
43-
block = expr_arr_to_block(exprs)
44-
:((J,u,p,t)->$(block))
37+
function generate_function(sys::NonlinearSystem; version::FunctionVersion = ArrayFunction)
38+
rhss = [eq.rhs for eq sys.eqs]
39+
return build_function(rhss, sys.vs, sys.ps; version = version)
4540
end
46-
47-
export NonlinearSystem
48-
export generate_nlsys_function

src/systems/systems.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
export generate_jacobian, generate_function
2+
3+
4+
abstract type AbstractSystem end
5+
6+
function generate_jacobian end
7+
function generate_function end
8+
9+
function build_function(rhss, vs, ps, args = (); version::FunctionVersion)
10+
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(vs)]
11+
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(ps)]
12+
(ls, rs) = zip(var_pairs..., param_pairs...)
13+
14+
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
15+
16+
if version === ArrayFunction
17+
X = gensym()
18+
sys_exprs = [:($X[$i] = $(convert(Expr, rhs))) for (i, rhs) enumerate(rhss)]
19+
let_expr = Expr(:let, var_eqs, build_expr(:block, sys_exprs))
20+
:(($X,u,p,$(args...)) -> $let_expr)
21+
elseif version === SArrayFunction
22+
sys_expr = build_expr(:tuple, [convert(Expr, rhs) for rhs rhss])
23+
let_expr = Expr(:let, var_eqs, sys_expr)
24+
:((u,p,$(args...)) -> begin
25+
X = $let_expr
26+
T = StaticArrays.similar_type(typeof(u), eltype(X))
27+
T(X)
28+
end)
29+
end
30+
end

src/utils.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,9 @@ function flatten_expr!(x)
3131
x
3232
end
3333

34-
toexpr(ex) = MacroTools.postwalk(x -> isa(x, Expression) ? convert(Expr, x) : x, ex)
35-
3634
function partition(f, xs)
3735
idxs = map(f, xs)
38-
not_idxs = eachindex(xs) .∉ (idxs,)
39-
return (xs[idxs], xs[not_idxs])
36+
return (xs[idxs], xs[(!).(idxs)])
4037
end
4138

4239
is_constant(::Constant) = true

0 commit comments

Comments
 (0)