Skip to content

Commit a11d28b

Browse files
Fix NonlinearSystem
1 parent b87bb66 commit a11d28b

File tree

6 files changed

+47
-56
lines changed

6 files changed

+47
-56
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ derivatives are zero. We use (unknown) variables for our nonlinear system.
8888
eqs = [0 ~ σ*(y-x),
8989
0 ~ x*-z)-y,
9090
0 ~ x*y - β*z]
91-
ns = NonlinearSystem(eqs)
91+
ns = NonlinearSystem(eqs, [x,y,z])
9292
nlsys_func = generate_function(ns)
9393
```
9494
@@ -272,7 +272,7 @@ a = y - x
272272
eqs = [0 ~ σ*a,
273273
0 ~ x*-z)-y,
274274
0 ~ x*y - β*z]
275-
ns = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
275+
ns = NonlinearSystem(eqs, [x,y,z])
276276
nlsys_func = generate_function(ns)
277277
```
278278

src/equations.jl

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,3 @@ Base.:(==)(a::Equation, b::Equation) = isequal((a.lhs, a.rhs), (b.lhs, b.rhs))
1010
Base.:~(lhs::Expression, rhs::Expression) = Equation(lhs, rhs)
1111
Base.:~(lhs::Expression, rhs::Number ) = Equation(lhs, rhs)
1212
Base.:~(lhs::Number , rhs::Expression) = Equation(lhs, rhs)
13-
14-
_is_known(O::Operation) = O.op.known
15-
_is_unknown(O::Operation) = !O.op.known
16-
17-
function extract_elements(eqs, predicates)
18-
result = [Variable[] for p predicates]
19-
vars = foldl(vars!, eqs; init=Set{Variable}())
20-
21-
for var vars
22-
for (i, p) enumerate(predicates)
23-
p(var) && (push!(result[i], var); break)
24-
end
25-
end
26-
27-
return result
28-
end
29-
30-
vars(exprs) = foldl(vars!, exprs; init = Set{Variable}())
31-
function vars!(vars, O)
32-
isa(O, Operation) || return vars
33-
for arg O.args
34-
if isa(arg, Operation)
35-
isa(arg.op, Variable) && push!(vars, arg.op)
36-
vars!(vars, arg)
37-
end
38-
end
39-
40-
return vars
41-
end

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,26 @@ function Base.convert(::Type{NLEq}, eq::Equation)
99
return NLEq(eq.rhs)
1010
end
1111
Base.:(==)(a::NLEq, b::NLEq) = a.rhs == b.rhs
12-
get_args(eq::NLEq) = Expression[eq.rhs]
1312

1413
struct NonlinearSystem <: AbstractSystem
1514
eqs::Vector{NLEq}
16-
vs::Vector{Variable}
15+
vs::Vector{Expression}
1716
ps::Vector{Variable}
17+
function NonlinearSystem(eqs, vs)
18+
rhss = [eq.rhs for eq eqs]
19+
ps = reduce(, map(_find_params(vs), rhss); init = vnil())
20+
new(eqs, vs, collect(ps))
21+
end
1822
end
1923

20-
function NonlinearSystem(eqs)
21-
vs, ps = extract_elements(eqs, [_is_unknown, _is_known])
22-
NonlinearSystem(eqs, vs, ps)
24+
vnil() = Set{Variable}()
25+
_find_params(vs) = Base.Fix2(_find_params, vs)
26+
function _find_params(O, vs)
27+
isa(O, Operation) || return vnil()
28+
any(isequal(O), vs) && return vnil()
29+
ps = reduce(, map(_find_params(vs), O.args); init = vnil())
30+
isa(O.op, Variable) && push!(ps, O.op)
31+
return ps
2332
end
2433

2534

@@ -31,10 +40,12 @@ end
3140

3241
function generate_jacobian(sys::NonlinearSystem; version::FunctionVersion = ArrayFunction)
3342
jac = calculate_jacobian(sys)
34-
return build_function(jac, sys.vs, sys.ps; version = version)
43+
return build_function(jac, clean.(sys.vs), sys.ps; version = version)
3544
end
3645

37-
function generate_function(sys::NonlinearSystem; version::FunctionVersion = ArrayFunction)
46+
function generate_function(sys::NonlinearSystem, vs, ps; version::FunctionVersion = ArrayFunction)
3847
rhss = [eq.rhs for eq sys.eqs]
39-
return build_function(rhss, sys.vs, sys.ps; version = version)
48+
vs′ = [clean(v) for v vs]
49+
ps′ = [clean(p) for p ps]
50+
return build_function(rhss, vs′, ps′; version = version)
4051
end

src/utils.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,17 @@ Base.occursin(t::Expression, x::Expression) = isequal(x, t)
6969

7070
clean(x::Variable) = x
7171
clean(O::Operation) = isa(O.op, Variable) ? O.op : throw(ArgumentError("invalid variable: $(O.op)"))
72+
73+
74+
vars(exprs) = foldl(vars!, exprs; init = Set{Variable}())
75+
function vars!(vars, O)
76+
isa(O, Operation) || return vars
77+
for arg O.args
78+
if isa(arg, Operation)
79+
isa(arg.op, Variable) && push!(vars, arg.op)
80+
vars!(vars, arg)
81+
end
82+
end
83+
84+
return vars
85+
end

test/derivatives.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,10 @@ d2 = D(sin(t)*cos(t))
3131
@test isequal(expand_derivatives(d1), t*cos(t)+sin(t))
3232
@test isequal(expand_derivatives(d2), simplify_constants(cos(t)*cos(t)+sin(t)*(-1*sin(t))))
3333

34-
@test_broken begin
3534
eqs = [0 ~ σ*(y-x),
3635
0 ~ x*-z)-y,
3736
0 ~ x*y - β*z]
38-
sys = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
37+
sys = NonlinearSystem(eqs, [x,y,z])
3938
jac = calculate_jacobian(sys)
4039
@test isequal(jac[1,1], σ*-1)
4140
@test isequal(jac[1,2], σ)
@@ -46,7 +45,6 @@ jac = calculate_jacobian(sys)
4645
@test isequal(jac[3,1], y)
4746
@test isequal(jac[3,2], x)
4847
@test isequal(jac[3,3], -1*β)
49-
end
5048

5149
# Variable dependence checking in differentiation
5250
@variables a(t) b(a)

test/system_construction.jl

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ function test_diffeq_inference(name, de, iv, dvs, ps)
1313
@test Set([p.name for p de.ps ]) == Set(ps)
1414
end
1515
end
16+
function test_nlsys_inference(name, de, vs, ps)
17+
@testset "NonlinearSystem construction: $name" begin
18+
@test Set(vs) == Set(vs)
19+
@test Set([p.name for p de.ps]) == Set(ps)
20+
end
21+
end
1622

1723
# Define a differential equation
1824
eqs = [D(x) ~ σ*(y-x),
@@ -96,21 +102,14 @@ jac = calculate_jacobian(de)
96102
f = ODEFunction(de)
97103
end
98104

99-
@test_broken begin
100105
# Define a nonlinear system
101106
eqs = [0 ~ σ*(y-x),
102107
0 ~ x*-z)-y,
103108
0 ~ x*y - β*z]
104-
ns = NonlinearSystem(eqs,[x,y,z],[t,σ,ρ,β])
105-
ns2 = NonlinearSystem(eqs)
106-
for el in (:vs, :ps)
107-
names2 = sort(collect(var.name for var in getfield(ns2,el)))
108-
names = sort(collect(var.name for var in getfield(ns,el)))
109-
@test names2 == names
110-
end
109+
ns = NonlinearSystem(eqs, [x,y,z])
110+
test_nlsys_inference("standard", ns, (x, y, z), (, , ))
111111

112112
generate_function(ns, [x,y,z], [σ,ρ,β])
113-
end
114113

115114
@derivatives D'~t
116115
@parameters A() B() C()
@@ -125,16 +124,15 @@ de = ODESystem(eqs)
125124
du [-1, -1/3]
126125
end
127126

128-
@test_broken begin
129127
# Now nonlinear system with only variables
130-
@variables x y z
128+
@variables x() y() z()
131129
@parameters σ() ρ() β()
132130

133131
# Define a nonlinear system
134132
eqs = [0 ~ σ*(y-x),
135133
0 ~ x*-z)-y,
136134
0 ~ x*y - β*z]
137-
ns = NonlinearSystem(eqs, [x,y,z], [σ,ρ,β])
135+
ns = NonlinearSystem(eqs, [x,y,z])
138136
jac = calculate_jacobian(ns)
139137
@testset "nlsys jacobian" begin
140138
@test isequal(jac[1,1], σ * -1)
@@ -156,8 +154,7 @@ f = @eval eval(nlsys_func)
156154
eqs = [0 ~ σ*a,
157155
0 ~ x*-z)-y,
158156
0 ~ x*y - β*z]
159-
ns = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
157+
ns = NonlinearSystem(eqs, [x,y,z])
160158
nlsys_func = generate_function(ns, [x,y,z], [σ,ρ,β])
161159
jac = calculate_jacobian(ns)
162160
jac = generate_jacobian(ns)
163-
end

0 commit comments

Comments
 (0)