Skip to content

Commit 254d106

Browse files
Merge pull request #122 from JuliaDiffEq/hg/fix/toexpr
Variable function fixes
2 parents 4b6e441 + 17e1e4e commit 254d106

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ and parameters. Therefore we label them as follows:
2828
using ModelingToolkit
2929

3030
# Define some variables
31-
@parameters t σ ρ β
31+
@parameters t() σ() ρ() β()
3232
@variables x(t) y(t) z(t)
3333
@derivatives D'~t
3434
```
@@ -80,22 +80,22 @@ state of the previous ODE. This is the nonlinear system defined by where the
8080
derivatives are zero. We use (unknown) variables for our nonlinear system.
8181
8282
```julia
83-
@variables x y z
84-
@parameters σ ρ β
83+
@variables x() y() z()
84+
@parameters σ() ρ() β()
8585

8686
# Define a nonlinear system
8787
eqs = [0 ~ σ*(y-x),
8888
0 ~ x*-z)-y,
8989
0 ~ x*y - β*z]
9090
ns = NonlinearSystem(eqs, [x,y,z])
91-
nlsys_func = generate_function(ns, [x,y,z], [ρ,σ,β])
91+
nlsys_func = generate_function(ns, [x,y,z], [σ,ρ,β])
9292
```
9393
9494
which generates:
9595
9696
```julia
9797
:((##364, u, p)->begin
98-
let (x, z, y, ρ, σ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
98+
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
9999
##364[1] = σ * (y - x)
100100
##364[2] = x * (ρ - z) - y
101101
##364[3] = x * y - β * z

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,6 @@ function calculate_jacobian(sys::ODESystem)
9494
return jac
9595
end
9696

97-
function generate_jacobian(sys::ODESystem; version::FunctionVersion = ArrayFunction)
98-
jac = calculate_jacobian(sys)
99-
return build_function(jac, sys.dvs, sys.ps, (sys.iv.name,); version = version)
100-
end
101-
10297
struct ODEToExpr
10398
sys::ODESystem
10499
end
@@ -113,6 +108,11 @@ function (f::ODEToExpr)(O::Operation)
113108
end
114109
(f::ODEToExpr)(x) = convert(Expr, x)
115110

111+
function generate_jacobian(sys::ODESystem; version::FunctionVersion = ArrayFunction)
112+
jac = calculate_jacobian(sys)
113+
return build_function(jac, sys.dvs, sys.ps, (sys.iv.name,), ODEToExpr(sys); version = version)
114+
end
115+
116116
function generate_function(sys::ODESystem, dvs, ps; version::FunctionVersion = ArrayFunction)
117117
rhss = [deq.rhs for deq sys.eqs]
118118
dvs′ = [clean(dv) for dv dvs]

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,26 @@ end
4545

4646
function generate_jacobian(sys::NonlinearSystem; version::FunctionVersion = ArrayFunction)
4747
jac = calculate_jacobian(sys)
48-
return build_function(jac, clean.(sys.vs), sys.ps; version = version)
48+
return build_function(jac, clean.(sys.vs), sys.ps, (), NLSysToExpr(sys); version = version)
4949
end
5050

51+
struct NLSysToExpr
52+
sys::NonlinearSystem
53+
end
54+
function (f::NLSysToExpr)(O::Operation)
55+
any(isequal(O), f.sys.vs) && return O.op.name # variables
56+
if isa(O.op, Variable)
57+
isempty(O.args) && return O.op.name # 0-ary parameters
58+
return build_expr(:call, Any[O.op.name; f.(O.args)])
59+
end
60+
return build_expr(:call, Any[O.op; f.(O.args)])
61+
end
62+
(f::NLSysToExpr)(x) = convert(Expr, x)
63+
64+
5165
function generate_function(sys::NonlinearSystem, vs, ps; version::FunctionVersion = ArrayFunction)
5266
rhss = [eq.rhs for eq sys.eqs]
5367
vs′ = [clean(v) for v vs]
5468
ps′ = [clean(p) for p ps]
55-
return build_function(rhss, vs′, ps′; version = version)
69+
return build_function(rhss, vs′, ps′, (), NLSysToExpr(sys); version = version)
5670
end

test/system_construction.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,12 @@ eqs = [0 ~ σ*(y-x),
119119
0 ~ x*y - β*z]
120120
ns = NonlinearSystem(eqs, [x,y,z])
121121
test_nlsys_inference("standard", ns, (x, y, z), (σ, ρ, β))
122-
123-
generate_function(ns, [x,y,z], [σ,ρ,β])
122+
@test begin
123+
f = eval(generate_function(ns, [x,y,z], [σ,ρ,β]))
124+
du = [0.0, 0.0, 0.0]
125+
f(du, [1,2,3], [1,2,3])
126+
du [1, -3, -7]
127+
end
124128

125129
@derivatives D'~t
126130
@parameters A() B() C()

0 commit comments

Comments
 (0)