Skip to content

Commit ecefdde

Browse files
Merge pull request #754 from SciML/nonlinearsystem
Complete NonlinearSystem
2 parents 4cdcd9d + cb4df64 commit ecefdde

File tree

3 files changed

+150
-28
lines changed

3 files changed

+150
-28
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
3636
[compat]
3737
ArrayInterface = "2.8"
3838
DataStructures = "0.17, 0.18"
39-
DiffEqBase = "6.48.1"
39+
DiffEqBase = "6.54.0"
4040
DiffEqJump = "6.7.5"
4141
DiffRules = "0.1, 1.0"
4242
Distributions = "0.24"
@@ -64,6 +64,7 @@ julia = "1.2"
6464
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
6565
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6666
GalacticOptim = "a75be94c-b780-496d-a8a9-0878b188d577"
67+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
6768
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
6869
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
6970
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -72,4 +73,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
7273
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7374

7475
[targets]
75-
test = ["Dagger", "ForwardDiff", "GalacticOptim", "OrdinaryDiffEq", "Optim", "Random", "SteadyStateDiffEq", "Test", "StochasticDiffEq"]
76+
test = ["Dagger", "ForwardDiff", "GalacticOptim", "NonlinearSolve", "OrdinaryDiffEq", "Optim", "Random", "SteadyStateDiffEq", "Test", "StochasticDiffEq"]

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 141 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,132 @@ jacobian_sparsity(sys::NonlinearSystem) =
8787
jacobian_sparsity([eq.rhs for eq equations(sys)],
8888
states(sys))
8989

90+
function DiffEqBase.NonlinearFunction(sys::NonlinearSystem, args...; kwargs...)
91+
NonlinearFunction{true}(sys, args...; kwargs...)
92+
end
93+
94+
"""
95+
```julia
96+
function DiffEqBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = states(sys),
97+
ps = parameters(sys);
98+
version = nothing,
99+
jac = false,
100+
sparse = false,
101+
kwargs...) where {iip}
102+
```
103+
104+
Create an `NonlinearFunction` from the [`NonlinearSystem`](@ref). The arguments
105+
`dvs` and `ps` are used to set the order of the dependent variable and parameter
106+
vectors, respectively.
107+
"""
108+
function DiffEqBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = states(sys),
109+
ps = parameters(sys), u0 = nothing;
110+
version = nothing,
111+
jac = false,
112+
eval_expression = true,
113+
sparse = false, simplify=false,
114+
kwargs...) where {iip}
115+
116+
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
117+
f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in f_gen) : f_gen
118+
f(u,p) = f_oop(u,p)
119+
f(du,u,p) = f_iip(du,u,p)
120+
121+
if jac
122+
jac_gen = generate_jacobian(sys, dvs, ps;
123+
simplify=simplify, sparse = sparse,
124+
expression=Val{eval_expression}, kwargs...)
125+
jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in jac_gen) : jac_gen
126+
_jac(u,p) = jac_oop(u,p)
127+
_jac(J,u,p) = jac_iip(J,u,p)
128+
else
129+
_jac = nothing
130+
end
131+
132+
NonlinearFunction{iip}(f,
133+
jac = _jac === nothing ? nothing : _jac,
134+
jac_prototype = sparse ? similar(sys.jac[],Float64) : nothing,
135+
syms = Symbol.(states(sys)))
136+
end
137+
138+
"""
139+
```julia
140+
function DiffEqBase.NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = states(sys),
141+
ps = parameters(sys);
142+
version = nothing,
143+
jac = false,
144+
sparse = false,
145+
kwargs...) where {iip}
146+
```
147+
148+
Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
149+
The arguments `dvs` and `ps` are used to set the order of the dependent
150+
variable and parameter vectors, respectively.
151+
"""
152+
struct NonlinearFunctionExpr{iip} end
153+
154+
function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = states(sys),
155+
ps = parameters(sys), u0 = nothing;
156+
version = nothing, tgrad=false,
157+
jac = false,
158+
linenumbers = false,
159+
sparse = false, simplify=false,
160+
kwargs...) where {iip}
161+
162+
idx = iip ? 2 : 1
163+
f = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
164+
165+
if jac
166+
_jac = generate_jacobian(sys, dvs, ps;
167+
sparse=sparse, simplify=simplify,
168+
expression=Val{true}, kwargs...)[idx]
169+
else
170+
_jac = :nothing
171+
end
172+
173+
jp_expr = sparse ? :(similar($(sys.jac[]),Float64)) : :nothing
174+
175+
ex = quote
176+
f = $f
177+
jac = $_jac
178+
NonlinearFunction{$iip}(f,
179+
jac = jac,
180+
jac_prototype = $jp_expr,
181+
syms = $(Symbol.(states(sys))))
182+
end
183+
!linenumbers ? striplines(ex) : ex
184+
end
185+
186+
function process_NonlinearProblem(constructor, sys::NonlinearSystem,u0map,parammap;
187+
version = nothing,
188+
jac = false,
189+
checkbounds = false, sparse = false,
190+
simplify=false,
191+
linenumbers = true, parallel=SerialForm(),
192+
eval_expression = true,
193+
kwargs...)
194+
dvs = states(sys)
195+
ps = parameters(sys)
196+
u0map′ = lower_mapnames(u0map)
197+
u0 = varmap_to_vars(u0map′,dvs; defaults=get_default_u0(sys))
198+
199+
if !(parammap isa DiffEqBase.NullParameters)
200+
parammap′ = lower_mapnames(parammap)
201+
p = varmap_to_vars(parammap′,ps; defaults=get_default_p(sys))
202+
else
203+
p = ps
204+
end
205+
206+
f = constructor(sys,dvs,ps,u0;jac=jac,checkbounds=checkbounds,
207+
linenumbers=linenumbers,parallel=parallel,simplify=simplify,
208+
sparse=sparse,eval_expression=eval_expression,kwargs...)
209+
return f, u0, p
210+
end
211+
212+
function DiffEqBase.NonlinearProblem(sys::NonlinearSystem, args...; kwargs...)
213+
NonlinearProblem{true}(sys, args...; kwargs...)
214+
end
215+
90216
"""
91217
```julia
92218
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,
@@ -101,19 +227,9 @@ Generates an NonlinearProblem from a NonlinearSystem and allows for automaticall
101227
symbolically calculating numerical enhancements.
102228
"""
103229
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,
104-
parammap=DiffEqBase.NullParameters();
105-
jac = false, sparse=false,
106-
checkbounds = false,
107-
linenumbers = true, parallel=SerialForm(),
108-
kwargs...) where iip
109-
dvs = states(sys)
110-
ps = parameters(sys)
111-
112-
f = generate_function(sys;checkbounds=checkbounds,linenumbers=linenumbers,
113-
parallel=parallel,sparse=sparse,expression=Val{false})
114-
u0 = varmap_to_vars(u0map,dvs; defaults=get_default_u0(sys))
115-
p = varmap_to_vars(parammap,ps; defaults=get_default_p(sys))
116-
NonlinearProblem(f,u0,p;kwargs...)
230+
parammap=DiffEqBase.NullParameters();kwargs...) where iip
231+
f, u0, p = process_NonlinearProblem(NonlinearFunction{iip}, sys, u0map, parammap; kwargs...)
232+
NonlinearProblem{iip}(f,u0,p;kwargs...)
117233
end
118234

119235
"""
@@ -132,23 +248,22 @@ numerical enhancements.
132248
"""
133249
struct NonlinearProblemExpr{iip} end
134250

251+
function NonlinearProblemExpr(sys::NonlinearSystem, args...; kwargs...)
252+
NonlinearProblemExpr{true}(sys, args...; kwargs...)
253+
end
254+
135255
function NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,
136-
parammap=DiffEqBase.NullParameters();
137-
jac = false, sparse=false,
138-
checkbounds = false,
139-
linenumbers = false, parallel=SerialForm(),
140-
kwargs...) where iip
141-
dvs = states(sys)
142-
ps = parameters(sys)
256+
parammap=DiffEqBase.NullParameters();
257+
kwargs...) where iip
258+
259+
f, u0, p = process_NonlinearProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap; kwargs...)
260+
linenumbers = get(kwargs, :linenumbers, true)
143261

144-
f = generate_function(sys;checkbounds=checkbounds,linenumbers=linenumbers,
145-
parallel=parallel,sparse=sparse,expression=Val{true})
146-
u0 = varmap_to_vars(u0map,dvs; defaults=get_default_u0(sys))
147-
p = varmap_to_vars(parammap,ps; defaults=get_default_p(sys))
148-
quote
262+
ex = quote
149263
f = $f
150264
u0 = $u0
151265
p = $p
152-
NonlinearProblem(f,u0,p;kwargs...)
266+
NonlinearProblem(f,u0,p;$(kwargs...))
153267
end
268+
!linenumbers ? striplines(ex) : ex
154269
end

test/nonlinearsystem.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ModelingToolkit, StaticArrays, LinearAlgebra
22
using DiffEqBase, SparseArrays
33
using Test
4+
using NonlinearSolve
45
using ModelingToolkit: value
56

67
canonequal(a, b) = isequal(simplify(a), simplify(b))
@@ -62,9 +63,14 @@ eqs = [0 ~ σ*a,
6263
0 ~ x*y - β*z]
6364
ns = NonlinearSystem(eqs, [x,y,z], [σ,ρ,β])
6465
nlsys_func = generate_function(ns, [x,y,z], [σ,ρ,β])
66+
nf = NonlinearFunction(ns)
6567
jac = calculate_jacobian(ns)
6668

6769
@test ModelingToolkit.jacobian_sparsity(ns).colptr == sparse(jac).colptr
6870
@test ModelingToolkit.jacobian_sparsity(ns).rowval == sparse(jac).rowval
6971

7072
jac = generate_jacobian(ns)
73+
74+
prob = NonlinearProblem(ns,ones(3),ones(3))
75+
sol = solve(prob,NewtonRaphson())
76+
@test sol.u[1] sol.u[2]

0 commit comments

Comments
 (0)