Skip to content

Commit 61f4914

Browse files
Unify generate_function
1 parent ca926fc commit 61f4914

File tree

5 files changed

+55
-60
lines changed

5 files changed

+55
-60
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ This can then generate the function. For example, we can see the
5656
generated code via:
5757

5858
```julia
59-
ModelingToolkit.generate_ode_function(de)
59+
generate_function(de)
6060

6161
## Which returns:
6262
:((du, u, p, t)->begin
@@ -88,7 +88,7 @@ eqs = [0 ~ σ*(y-x),
8888
0 ~ x*-z)-y,
8989
0 ~ x*y - β*z]
9090
ns = NonlinearSystem(eqs)
91-
nlsys_func = ModelingToolkit.generate_nlsys_function(ns)
91+
nlsys_func = generate_function(ns)
9292
```
9393

9494
which generates:
@@ -270,7 +270,7 @@ eqs = [0 ~ σ*a,
270270
0 ~ x*-z)-y,
271271
0 ~ x*y - β*z]
272272
ns = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
273-
nlsys_func = ModelingToolkit.generate_nlsys_function(ns)
273+
nlsys_func = generate_function(ns)
274274
```
275275

276276
expands to:

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
export DiffEqSystem, ODEFunction
2+
3+
14
using Base: RefValue
25

36

@@ -12,6 +15,7 @@ function Base.convert(::Type{DiffEq}, eq::Equation)
1215
isintermediate(eq) && throw(ArgumentError("intermediate equation received"))
1316
return DiffEq(eq.lhs.op, eq.lhs.args[1], eq.rhs)
1417
end
18+
Base.convert(::Type{Equation}, eq::DiffEq) = Equation(eq.D(eq.var), eq.rhs)
1519
Base.:(==)(a::DiffEq, b::DiffEq) = (a.D, a.var, a.rhs) == (b.D, b.var, b.rhs)
1620
get_args(eq::DiffEq) = Expression[eq.var, eq.rhs]
1721

@@ -42,26 +46,6 @@ function DiffEqSystem(eqs, iv)
4246
end
4347

4448

45-
function generate_ode_function(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)
46-
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(sys.dvs)]
47-
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(sys.ps )]
48-
(ls, rs) = collect(zip(var_pairs..., param_pairs...))
49-
50-
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
51-
sys_exprs = build_expr(:tuple, [convert(Expr, eq.rhs) for eq sys.eqs])
52-
let_expr = Expr(:let, var_eqs, sys_exprs)
53-
54-
if version === ArrayFunction
55-
:((du,u,p,t) -> du .= $let_expr)
56-
elseif version === SArrayFunction
57-
:((u,p,t) -> begin
58-
du = $let_expr
59-
T = StaticArrays.similar_type(typeof(u), eltype(du))
60-
T(du)
61-
end)
62-
end
63-
end
64-
6549
function calculate_jacobian(sys::DiffEqSystem, simplify=true)
6650
isempty(sys.jac[]) || return sys.jac[] # use cached Jacobian, if possible
6751
rhs = [eq.rhs for eq in sys.eqs]
@@ -71,6 +55,8 @@ function calculate_jacobian(sys::DiffEqSystem, simplify=true)
7155
return jac
7256
end
7357

58+
system_eqs(sys::DiffEqSystem) = collect(Equation, sys.eqs)
59+
system_extras(::DiffEqSystem) = Equation[]
7460
system_vars(sys::DiffEqSystem) = sys.dvs
7561
system_params(sys::DiffEqSystem) = sys.ps
7662

@@ -108,14 +94,10 @@ function generate_ode_iW(sys::DiffEqSystem, simplify=true)
10894
end
10995

11096
function DiffEqBase.ODEFunction(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)
111-
expr = generate_ode_function(sys; version = version)
97+
expr = generate_function(sys; version = version)
11298
if version === ArrayFunction
11399
ODEFunction{true}(eval(expr))
114100
elseif version === SArrayFunction
115101
ODEFunction{false}(eval(expr))
116102
end
117103
end
118-
119-
120-
export DiffEqSystem, ODEFunction
121-
export generate_ode_function
Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
export NonlinearSystem
2+
3+
14
struct NonlinearSystem <: AbstractSystem
25
eqs::Vector{Equation}
36
vs::Vector{Variable}
@@ -9,25 +12,9 @@ function NonlinearSystem(eqs)
912
NonlinearSystem(eqs, vs, ps)
1013
end
1114

12-
iscalc(eq) = isequal(eq.lhs, Constant(0))
13-
14-
function generate_nlsys_function(sys::NonlinearSystem)
15-
sys_eqs, calc_eqs = partition(iscalc, sys.eqs)
16-
17-
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(sys.vs)]
18-
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(sys.ps)]
19-
calc_pairs = [(eq.lhs.name, convert(Expr, eq.rhs)) for eq calc_eqs if isa(eq.lhs, Variable)]
20-
(ls, rs) = collect(zip(var_pairs..., param_pairs..., calc_pairs...))
21-
22-
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
23-
sys_exprs = build_expr(:tuple, [convert(Expr, eq.rhs) for eq sys_eqs])
24-
let_expr = Expr(:let, var_eqs, sys_exprs)
2515

26-
:((du,u,p) -> du .= $let_expr)
27-
end
28-
29-
function calculate_jacobian(sys::NonlinearSystem,simplify=true)
30-
sys_eqs, calc_eqs = partition(iscalc, sys.eqs)
16+
function calculate_jacobian(sys::NonlinearSystem, simplify=true)
17+
sys_eqs, calc_eqs = system_eqs(sys), filter(iscalc, sys.eqs)
3118
rhs = [eq.rhs for eq in sys_eqs]
3219

3320
for calc_eq calc_eqs
@@ -39,8 +26,9 @@ function calculate_jacobian(sys::NonlinearSystem,simplify=true)
3926
sys_exprs
4027
end
4128

29+
iscalc(eq) = !isequal(eq.lhs, Constant(0))
30+
31+
system_eqs(sys::NonlinearSystem) = filter(!iscalc, sys.eqs)
32+
system_extras(sys::NonlinearSystem) = filter(eq -> isa(eq.lhs, Variable), sys.eqs)
4233
system_vars(sys::NonlinearSystem) = sys.vs
4334
system_params(sys::NonlinearSystem) = sys.ps
44-
45-
export NonlinearSystem
46-
export generate_nlsys_function

src/systems/systems.jl

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
export generate_jacobian
1+
export generate_jacobian, generate_function
22

33

44
abstract type AbstractSystem end
55

6-
7-
function system_vars end
6+
function system_eqs end
7+
function system_extras end
8+
function system_vars end
89
function system_params end
910

1011
function generate_jacobian(sys::AbstractSystem, simplify = true)
@@ -17,3 +18,27 @@ function generate_jacobian(sys::AbstractSystem, simplify = true)
1718
block = expr_arr_to_block(exprs)
1819
:((J,u,p,t) -> $(block))
1920
end
21+
22+
function generate_function(sys::AbstractSystem; version::FunctionVersion = ArrayFunction)
23+
sys_eqs, calc_eqs = system_eqs(sys), system_extras(sys)
24+
vs, ps = system_vars(sys), system_params(sys)
25+
26+
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(vs)]
27+
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(ps)]
28+
calc_pairs = [(eq.lhs.name, convert(Expr, eq.rhs)) for eq calc_eqs]
29+
(ls, rs) = collect(zip(var_pairs..., param_pairs..., calc_pairs...))
30+
31+
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
32+
sys_exprs = build_expr(:tuple, [convert(Expr, eq.rhs) for eq sys_eqs])
33+
let_expr = Expr(:let, var_eqs, sys_exprs)
34+
35+
if version === ArrayFunction
36+
:((du,u,p,t) -> du .= $let_expr)
37+
elseif version === SArrayFunction
38+
:((u,p,t) -> begin
39+
du = $let_expr
40+
T = StaticArrays.similar_type(typeof(u), eltype(du))
41+
T(du)
42+
end)
43+
end
44+
end

test/system_construction.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ eqs = [D(x) ~ σ*(y-x),
1212
D(y) ~ x*-z)-y,
1313
D(z) ~ x*y - β*z]
1414
de = DiffEqSystem(eqs,t,[x,y,z],[σ,ρ,β])
15-
ModelingToolkit.generate_ode_function(de)
16-
ModelingToolkit.generate_ode_function(de;version=ModelingToolkit.SArrayFunction)
15+
generate_function(de)
16+
generate_function(de;version=ModelingToolkit.SArrayFunction)
1717
jac_expr = generate_jacobian(de)
1818
jac = ModelingToolkit.calculate_jacobian(de)
1919
f = ODEFunction(de)
@@ -39,7 +39,7 @@ test_vars_extraction(de, de2)
3939
D(y) ~ x*-z)-y,
4040
D(z) ~ x*y - β*z]
4141
de = DiffEqSystem(eqs,[t],[x,y,z],[σ,ρ,β])
42-
ModelingToolkit.generate_ode_function(de)
42+
generate_function(de)
4343

4444
#=
4545
```julia
@@ -83,7 +83,7 @@ eqs = [D(x) ~ σ*a,
8383
D(y) ~ x*-z)-y,
8484
D(z) ~ x*y - β*z]
8585
de = DiffEqSystem(eqs,t,[x,y,z],[σ,ρ,β])
86-
ModelingToolkit.generate_ode_function(de)
86+
generate_function(de)
8787
jac = ModelingToolkit.calculate_jacobian(de)
8888
f = ODEFunction(de)
8989

@@ -99,7 +99,7 @@ for el in (:vs, :ps)
9999
@test names2 == names
100100
end
101101

102-
ModelingToolkit.generate_nlsys_function(ns)
102+
generate_function(ns)
103103

104104
@Deriv D'~t
105105
@Param A B C
@@ -110,7 +110,7 @@ de = DiffEqSystem(eqs,t,[x,y],[A,B,C])
110110
test_vars_extraction(de, DiffEqSystem(eqs,t))
111111
test_vars_extraction(de, DiffEqSystem(eqs))
112112
@test begin
113-
f = eval(ModelingToolkit.generate_ode_function(de))
113+
f = eval(generate_function(de))
114114
du = [0.0,0.0]
115115
f(du, [1.0,2.0], [1,2,3], 0.0)
116116
du [-1, -1/3]
@@ -137,7 +137,7 @@ jac = ModelingToolkit.calculate_jacobian(ns)
137137
@test jac[3,2] == x
138138
@test jac[3,3] == -1 * β
139139
end
140-
nlsys_func = ModelingToolkit.generate_nlsys_function(ns)
140+
nlsys_func = generate_function(ns)
141141
jac_func = generate_jacobian(ns)
142142
f = @eval eval(nlsys_func)
143143

@@ -148,6 +148,6 @@ eqs = [a ~ y-x,
148148
0 ~ x*-z)-y,
149149
0 ~ x*y - β*z]
150150
ns = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
151-
nlsys_func = ModelingToolkit.generate_nlsys_function(ns)
151+
nlsys_func = generate_function(ns)
152152
jac = ModelingToolkit.calculate_jacobian(ns)
153153
jac = generate_jacobian(ns)

0 commit comments

Comments
 (0)