Skip to content

Commit ca926fc

Browse files
Refactor function generation
Closes #93.
1 parent c3dc123 commit ca926fc

File tree

4 files changed

+42
-66
lines changed

4 files changed

+42
-66
lines changed

README.md

Lines changed: 16 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,11 @@ ModelingToolkit.generate_ode_function(de)
6060

6161
## Which returns:
6262
:((du, u, p, t)->begin
63-
x = u[1]
64-
y = u[2]
65-
z = u[3]
66-
σ = p[1]
67-
ρ = p[2]
68-
β = p[3]
69-
x_t = σ * (y - x)
70-
y_t = x *- z) - y
71-
z_t = x * y - β * z
72-
du[1] = x_t
73-
du[2] = y_t
74-
du[3] = z_t
75-
end
76-
end)
63+
du .= let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
64+
* (y - x), x *- z) - y, x * y - β * z)
65+
end
66+
end)
67+
7768
```
7869

7970
and get the generated function via:
@@ -103,19 +94,11 @@ nlsys_func = ModelingToolkit.generate_nlsys_function(ns)
10394
which generates:
10495

10596
```julia
106-
(du, u, p)->begin # C:\Users\Chris\.julia\v0.6\ModelingToolkit\src\systems.jl, line 51:
107-
begin # C:\Users\Chris\.julia\v0.6\ModelingToolkit\src\utils.jl, line 2:
108-
y = u[1]
109-
x = u[2]
110-
z = u[3]
111-
σ = p[1]
112-
ρ = p[2]
113-
β = p[3]
114-
resid[1] = σ * (y - x)
115-
resid[2] = x *- z) - y
116-
resid[3] = x * y - β * z
117-
end
118-
end
97+
:((du, u, p)->begin
98+
du .= let (y, x, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
99+
* (y - x), x *- z) - y, x * y - β * z)
100+
end
101+
end)
119102
```
120103

121104
We can use this to build a nonlinear function for use with NLsolve.jl:
@@ -293,20 +276,12 @@ nlsys_func = ModelingToolkit.generate_nlsys_function(ns)
293276
expands to:
294277

295278
```julia
296-
:((du, u, p)->begin # C:\Users\Chris\.julia\v0.6\ModelingToolkit\src\systems.jl, line 85:
297-
begin # C:\Users\Chris\.julia\v0.6\ModelingToolkit\src\utils.jl, line 2:
298-
x = u[1]
299-
y = u[2]
300-
z = u[3]
301-
σ = p[1]
302-
ρ = p[2]
303-
β = p[3]
304-
a = y - x
305-
resid[1] = σ * a
306-
resid[2] = x *- z) - y
307-
resid[3] = x * y - β * z
308-
end
309-
end)
279+
:((du, u, p)->begin
280+
du .= let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
281+
* (y - x), x *- z) - y, x * y - β * z)
282+
end
283+
end)
284+
310285
```
311286

312287
In addition, the Jacobian calculations take into account intermediate variables

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,22 @@ end
4343

4444

4545
function generate_ode_function(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)
46-
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
47-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
46+
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(sys.dvs)]
47+
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(sys.ps )]
48+
(ls, rs) = collect(zip(var_pairs..., param_pairs...))
49+
50+
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
4851
sys_exprs = build_expr(:tuple, [convert(Expr, eq.rhs) for eq sys.eqs])
52+
let_expr = Expr(:let, var_eqs, sys_exprs)
53+
4954
if version === ArrayFunction
50-
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.iv.name)"))) for i in eachindex(sys.dvs)]
51-
du_expr = :(du .= $sys_exprs)
52-
exprs = vcat(var_exprs,param_exprs,du_expr)
53-
block = expr_arr_to_block(exprs)
54-
:((du,u,p,t)->$(toexpr(block)))
55+
:((du,u,p,t) -> du .= $let_expr)
5556
elseif version === SArrayFunction
56-
svector_expr = quote
57-
du = $sys_exprs
57+
:((u,p,t) -> begin
58+
du = $let_expr
5859
T = StaticArrays.similar_type(typeof(u), eltype(du))
5960
T(du)
60-
end
61-
exprs = vcat(var_exprs,param_exprs,svector_expr)
62-
block = expr_arr_to_block(exprs)
63-
:((u,p,t)->$(toexpr(block)))
61+
end)
6462
end
6563
end
6664

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,25 @@ function NonlinearSystem(eqs)
99
NonlinearSystem(eqs, vs, ps)
1010
end
1111

12+
iscalc(eq) = isequal(eq.lhs, Constant(0))
13+
1214
function generate_nlsys_function(sys::NonlinearSystem)
13-
var_exprs = [:($(sys.vs[i].name) = u[$i]) for i in 1:length(sys.vs)]
14-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
15-
sys_eqs, calc_eqs = partition(eq -> isequal(eq.lhs, Constant(0)), sys.eqs)
16-
calc_exprs = [:($(eq.lhs.name) = $(eq.rhs)) for eq in calc_eqs if isa(eq.lhs, Variable)]
17-
sys_exprs = [:($(Symbol("resid[$i]")) = $(sys_eqs[i].rhs)) for i in eachindex(sys_eqs)]
18-
19-
exprs = vcat(var_exprs,param_exprs,calc_exprs,sys_exprs)
20-
block = expr_arr_to_block(exprs)
21-
:((du,u,p)->$(block))
15+
sys_eqs, calc_eqs = partition(iscalc, sys.eqs)
16+
17+
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(sys.vs)]
18+
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(sys.ps)]
19+
calc_pairs = [(eq.lhs.name, convert(Expr, eq.rhs)) for eq calc_eqs if isa(eq.lhs, Variable)]
20+
(ls, rs) = collect(zip(var_pairs..., param_pairs..., calc_pairs...))
21+
22+
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
23+
sys_exprs = build_expr(:tuple, [convert(Expr, eq.rhs) for eq sys_eqs])
24+
let_expr = Expr(:let, var_eqs, sys_exprs)
25+
26+
:((du,u,p) -> du .= $let_expr)
2227
end
2328

2429
function calculate_jacobian(sys::NonlinearSystem,simplify=true)
25-
sys_eqs, calc_eqs = partition(eq -> isequal(eq.lhs, Constant(0)), sys.eqs)
30+
sys_eqs, calc_eqs = partition(iscalc, sys.eqs)
2631
rhs = [eq.rhs for eq in sys_eqs]
2732

2833
for calc_eq calc_eqs

src/utils.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ function flatten_expr!(x)
3131
x
3232
end
3333

34-
toexpr(ex) = MacroTools.postwalk(x -> isa(x, Expression) ? convert(Expr, x) : x, ex)
35-
3634
function partition(f, xs)
3735
idxs = map(f, xs)
3836
return (xs[idxs], xs[(!).(idxs)])

0 commit comments

Comments
 (0)