Skip to content

Commit 6d6b5bf

Browse files
Fix independent variable usage
1 parent e69c962 commit 6d6b5bf

File tree

4 files changed

+32
-35
lines changed

4 files changed

+32
-35
lines changed

README.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ nlsys_func = generate_function(ns)
9595
which generates:
9696
9797
```julia
98-
:((##366, u, p, t)->begin
99-
let (y, z, x, ρ, σ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
100-
##366[1] = σ * (y - x)
101-
##366[2] = x * (ρ - z) - y
102-
##366[3] = x * y - β * z
98+
:((##364, u, p)->begin
99+
let (x, z, y, ρ, σ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
100+
##364[1] = σ * (y - x)
101+
##364[2] = x * (ρ - z) - y
102+
##364[3] = x * y - β * z
103103
end
104104
end)
105105
```
@@ -279,11 +279,11 @@ nlsys_func = generate_function(ns)
279279
expands to:
280280
281281
```julia
282-
:((##367, u, p, t)->begin
282+
:((##365, u, p)->begin
283283
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
284-
##367[1] = σ * (y - x)
285-
##367[2] = x * (ρ - z) - y
286-
##367[3] = x * y - β * z
284+
##365[1] = σ * (y - x)
285+
##365[2] = x * (ρ - z) - y
286+
##365[3] = x * y - β * z
287287
end
288288
end)
289289
```

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ function Base.convert(::Type{DiffEq}, eq::Equation)
1515
isintermediate(eq) && throw(ArgumentError("intermediate equation received"))
1616
return DiffEq(eq.lhs.op, eq.lhs.args[1], eq.rhs)
1717
end
18-
Base.convert(::Type{Equation}, eq::DiffEq) = Equation(eq.D(eq.var), eq.rhs)
1918
Base.:(==)(a::DiffEq, b::DiffEq) = (a.D, a.var, a.rhs) == (b.D, b.var, b.rhs)
2019
get_args(eq::DiffEq) = Expression[eq.var, eq.rhs]
2120

@@ -55,9 +54,15 @@ function calculate_jacobian(sys::DiffEqSystem)
5554
return jac
5655
end
5756

58-
system_eqs(sys::DiffEqSystem) = collect(Equation, sys.eqs)
59-
system_vars(sys::DiffEqSystem) = sys.dvs
60-
system_params(sys::DiffEqSystem) = sys.ps
57+
function generate_jacobian(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)
58+
jac = calculate_jacobian(sys)
59+
return build_function(jac, sys.dvs, sys.ps, (sys.iv.name,); version = version)
60+
end
61+
62+
function generate_function(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)
63+
rhss = [eq.rhs for eq sys.eqs]
64+
return build_function(rhss, sys.dvs, sys.ps, (sys.iv.name,); version = version)
65+
end
6166

6267

6368
function generate_ode_iW(sys::DiffEqSystem, simplify=true; version::FunctionVersion = ArrayFunction)
@@ -80,7 +85,7 @@ function generate_ode_iW(sys::DiffEqSystem, simplify=true; version::FunctionVers
8085
iW_t = simplify_constants.(iW_t)
8186
end
8287

83-
vs, ps = system_vars(sys), system_params(sys)
88+
vs, ps = sys.dvs, sys.ps
8489
iW_func = build_function(iW , vs, ps, (:gam,:t); version = version)
8590
iW_t_func = build_function(iW_t, vs, ps, (:gam,:t); version = version)
8691

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ function Base.convert(::Type{NLEq}, eq::Equation)
88
isequal(eq.lhs, Constant(0)) || return NLEq(eq.rhs - eq.lhs)
99
return NLEq(eq.rhs)
1010
end
11-
Base.convert(::Type{Equation}, eq::NLEq) = Equation(0, eq.rhs)
1211
Base.:(==)(a::NLEq, b::NLEq) = a.rhs == b.rhs
1312
get_args(eq::NLEq) = Expression[eq.rhs]
1413

@@ -30,6 +29,12 @@ function calculate_jacobian(sys::NonlinearSystem)
3029
return jac
3130
end
3231

33-
system_eqs(sys::NonlinearSystem) = collect(Equation, sys.eqs)
34-
system_vars(sys::NonlinearSystem) = sys.vs
35-
system_params(sys::NonlinearSystem) = sys.ps
32+
function generate_jacobian(sys::NonlinearSystem; version::FunctionVersion = ArrayFunction)
33+
jac = calculate_jacobian(sys)
34+
return build_function(jac, sys.vs, sys.ps; version = version)
35+
end
36+
37+
function generate_function(sys::NonlinearSystem; version::FunctionVersion = ArrayFunction)
38+
rhss = [eq.rhs for eq sys.eqs]
39+
return build_function(rhss, sys.vs, sys.ps; version = version)
40+
end

src/systems/systems.jl

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,13 @@ export generate_jacobian, generate_function
33

44
abstract type AbstractSystem end
55

6-
function system_eqs end
7-
function system_vars end
8-
function system_params end
6+
function generate_jacobian end
7+
function generate_function end
98

10-
function generate_jacobian(sys::AbstractSystem; version = ArrayFunction)
11-
vs, ps = system_vars(sys), system_params(sys)
12-
jac = calculate_jacobian(sys)
13-
return build_function(jac, vs, ps, (:t,); version = version)
14-
end
15-
16-
function generate_function(sys::AbstractSystem; version::FunctionVersion = ArrayFunction)
17-
sys_eqs = system_eqs(sys)
18-
vs, ps = system_vars(sys), system_params(sys)
19-
return build_function([eq.rhs for eq sys_eqs], vs, ps, (:t,); version = version)
20-
end
21-
22-
function build_function(rhss, vs, ps, args; version::FunctionVersion)
9+
function build_function(rhss, vs, ps, args = (); version::FunctionVersion)
2310
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(vs)]
2411
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(ps)]
25-
(ls, rs) = collect(zip(var_pairs..., param_pairs...))
12+
(ls, rs) = zip(var_pairs..., param_pairs...)
2613

2714
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
2815

0 commit comments

Comments
 (0)