Skip to content

Commit 40b0370

Browse files
Refactor nlsys equation storage
Disallow intermediate equations.
1 parent 61f4914 commit 40b0370

File tree

4 files changed

+22
-27
lines changed

4 files changed

+22
-27
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function DiffEqSystem(eqs, iv)
4646
end
4747

4848

49-
function calculate_jacobian(sys::DiffEqSystem, simplify=true)
49+
function calculate_jacobian(sys::DiffEqSystem)
5050
isempty(sys.jac[]) || return sys.jac[] # use cached Jacobian, if possible
5151
rhs = [eq.rhs for eq in sys.eqs]
5252

@@ -56,15 +56,14 @@ function calculate_jacobian(sys::DiffEqSystem, simplify=true)
5656
end
5757

5858
system_eqs(sys::DiffEqSystem) = collect(Equation, sys.eqs)
59-
system_extras(::DiffEqSystem) = Equation[]
6059
system_vars(sys::DiffEqSystem) = sys.dvs
6160
system_params(sys::DiffEqSystem) = sys.ps
6261

6362

6463
function generate_ode_iW(sys::DiffEqSystem, simplify=true)
6564
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
6665
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
67-
jac = calculate_jacobian(sys, simplify)
66+
jac = calculate_jacobian(sys)
6867

6968
gam = Parameter(:gam)
7069

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
export NonlinearSystem
22

33

4+
struct NLEq
5+
rhs::Expression
6+
end
7+
function Base.convert(::Type{NLEq}, eq::Equation)
8+
isequal(eq.lhs, Constant(0)) || throw(ArgumentError("nonzero lhs received"))
9+
return NLEq(eq.rhs)
10+
end
11+
Base.convert(::Type{Equation}, eq::NLEq) = Equation(0, eq.rhs)
12+
413
struct NonlinearSystem <: AbstractSystem
5-
eqs::Vector{Equation}
14+
eqs::Vector{NLEq}
615
vs::Vector{Variable}
716
ps::Vector{Variable}
817
end
@@ -13,22 +22,12 @@ function NonlinearSystem(eqs)
1322
end
1423

1524

16-
function calculate_jacobian(sys::NonlinearSystem, simplify=true)
17-
sys_eqs, calc_eqs = system_eqs(sys), filter(iscalc, sys.eqs)
18-
rhs = [eq.rhs for eq in sys_eqs]
19-
20-
for calc_eq calc_eqs
21-
find_replace!.(rhs, calc_eq.lhs, calc_eq.rhs)
22-
end
23-
24-
sys_exprs = calculate_jacobian(rhs,sys.vs)
25-
sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs]
26-
sys_exprs
25+
function calculate_jacobian(sys::NonlinearSystem)
26+
rhs = [eq.rhs for eq in sys.eqs]
27+
jac = expand_derivatives.(calculate_jacobian(rhs, sys.vs))
28+
return jac
2729
end
2830

29-
iscalc(eq) = !isequal(eq.lhs, Constant(0))
30-
31-
system_eqs(sys::NonlinearSystem) = filter(!iscalc, sys.eqs)
32-
system_extras(sys::NonlinearSystem) = filter(eq -> isa(eq.lhs, Variable), sys.eqs)
31+
system_eqs(sys::NonlinearSystem) = collect(Equation, sys.eqs)
3332
system_vars(sys::NonlinearSystem) = sys.vs
3433
system_params(sys::NonlinearSystem) = sys.ps

src/systems/systems.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,27 @@ export generate_jacobian, generate_function
44
abstract type AbstractSystem end
55

66
function system_eqs end
7-
function system_extras end
87
function system_vars end
98
function system_params end
109

11-
function generate_jacobian(sys::AbstractSystem, simplify = true)
10+
function generate_jacobian(sys::AbstractSystem)
1211
vs, ps = system_vars(sys), system_params(sys)
1312
var_exprs = [:($(vs[i].name) = u[$i]) for i in eachindex(vs)]
1413
param_exprs = [:($(ps[i].name) = p[$i]) for i in eachindex(ps)]
15-
jac = calculate_jacobian(sys, simplify)
14+
jac = calculate_jacobian(sys)
1615
jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
1716
exprs = vcat(var_exprs, param_exprs, vec(jac_exprs))
1817
block = expr_arr_to_block(exprs)
1918
:((J,u,p,t) -> $(block))
2019
end
2120

2221
function generate_function(sys::AbstractSystem; version::FunctionVersion = ArrayFunction)
23-
sys_eqs, calc_eqs = system_eqs(sys), system_extras(sys)
22+
sys_eqs = system_eqs(sys)
2423
vs, ps = system_vars(sys), system_params(sys)
2524

2625
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(vs)]
2726
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(ps)]
28-
calc_pairs = [(eq.lhs.name, convert(Expr, eq.rhs)) for eq calc_eqs]
29-
(ls, rs) = collect(zip(var_pairs..., param_pairs..., calc_pairs...))
27+
(ls, rs) = collect(zip(var_pairs..., param_pairs...))
3028

3129
var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
3230
sys_exprs = build_expr(:tuple, [convert(Expr, eq.rhs) for eq sys_eqs])

test/system_construction.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,7 @@ f = @eval eval(nlsys_func)
143143

144144
# Intermediate calculations
145145
# Define a nonlinear system
146-
eqs = [a ~ y-x,
147-
0 ~ σ*a,
146+
eqs = [0 ~ σ*a,
148147
0 ~ x*-z)-y,
149148
0 ~ x*y - β*z]
150149
ns = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])

0 commit comments

Comments
 (0)