Skip to content

Commit 2b5b014

Browse files
Complete NonlinearSystem
1 parent 4cdcd9d commit 2b5b014

File tree

2 files changed

+142
-27
lines changed

2 files changed

+142
-27
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
3636
[compat]
3737
ArrayInterface = "2.8"
3838
DataStructures = "0.17, 0.18"
39-
DiffEqBase = "6.48.1"
39+
DiffEqBase = "6.53.6"
4040
DiffEqJump = "6.7.5"
4141
DiffRules = "0.1, 1.0"
4242
Distributions = "0.24"

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 141 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,132 @@ jacobian_sparsity(sys::NonlinearSystem) =
8787
jacobian_sparsity([eq.rhs for eq equations(sys)],
8888
states(sys))
8989

90+
function DiffEqBase.NonlinearFunction(sys::NonlinearSystem, args...; kwargs...)
91+
NonlinearFunction{true}(sys, args...; kwargs...)
92+
end
93+
94+
"""
95+
```julia
96+
function DiffEqBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = states(sys),
97+
ps = parameters(sys);
98+
version = nothing,
99+
jac = false,
100+
sparse = false,
101+
kwargs...) where {iip}
102+
```
103+
104+
Create an `NonlinearFunction` from the [`NonlinearSystem`](@ref). The arguments
105+
`dvs` and `ps` are used to set the order of the dependent variable and parameter
106+
vectors, respectively.
107+
"""
108+
function DiffEqBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = states(sys),
109+
ps = parameters(sys), u0 = nothing;
110+
version = nothing,
111+
jac = false,
112+
eval_expression = true,
113+
sparse = false, simplify=false,
114+
kwargs...) where {iip}
115+
116+
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
117+
f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in f_gen) : f_gen
118+
f(u,p) = f_oop(u,p)
119+
f(du,u,p) = f_iip(du,u,p)
120+
121+
if jac
122+
jac_gen = generate_jacobian(sys, dvs, ps;
123+
simplify=simplify, sparse = sparse,
124+
expression=Val{eval_expression}, kwargs...)
125+
jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in jac_gen) : jac_gen
126+
_jac(u,p) = jac_oop(u,p)
127+
_jac(J,u,p) = jac_iip(J,u,p)
128+
else
129+
_jac = nothing
130+
end
131+
132+
NonlinearFunction{iip}(f,
133+
jac = _jac === nothing ? nothing : _jac,
134+
jac_prototype = sparse ? similar(sys.jac[],Float64) : nothing,
135+
syms = Symbol.(states(sys)))
136+
end
137+
138+
"""
139+
```julia
140+
function DiffEqBase.NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = states(sys),
141+
ps = parameters(sys);
142+
version = nothing,
143+
jac = false,
144+
sparse = false,
145+
kwargs...) where {iip}
146+
```
147+
148+
Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
149+
The arguments `dvs` and `ps` are used to set the order of the dependent
150+
variable and parameter vectors, respectively.
151+
"""
152+
struct NonlinearFunctionExpr{iip} end
153+
154+
function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = states(sys),
155+
ps = parameters(sys), u0 = nothing;
156+
version = nothing, tgrad=false,
157+
jac = false,
158+
linenumbers = false,
159+
sparse = false, simplify=false,
160+
kwargs...) where {iip}
161+
162+
idx = iip ? 2 : 1
163+
f = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
164+
165+
if jac
166+
_jac = generate_jacobian(sys, dvs, ps;
167+
sparse=sparse, simplify=simplify,
168+
expression=Val{true}, kwargs...)[idx]
169+
else
170+
_jac = :nothing
171+
end
172+
173+
jp_expr = sparse ? :(similar($(sys.jac[]),Float64)) : :nothing
174+
175+
ex = quote
176+
f = $f
177+
jac = $_jac
178+
NonlinearFunction{$iip}(f,
179+
jac = jac,
180+
jac_prototype = $jp_expr,
181+
syms = $(Symbol.(states(sys))))
182+
end
183+
!linenumbers ? striplines(ex) : ex
184+
end
185+
186+
function process_NonlinearProblem(constructor, sys::NonlinearSystem,u0map,parammap;
187+
version = nothing,
188+
jac = false,
189+
checkbounds = false, sparse = false,
190+
simplify=false,
191+
linenumbers = true, parallel=SerialForm(),
192+
eval_expression = true,
193+
kwargs...)
194+
dvs = states(sys)
195+
ps = parameters(sys)
196+
u0map′ = lower_mapnames(u0map)
197+
u0 = varmap_to_vars(u0map′,dvs; defaults=get_default_u0(sys))
198+
199+
if !(parammap isa DiffEqBase.NullParameters)
200+
parammap′ = lower_mapnames(parammap)
201+
p = varmap_to_vars(parammap′,ps; defaults=get_default_p(sys))
202+
else
203+
p = ps
204+
end
205+
206+
f = constructor(sys,dvs,ps,u0;jac=jac,checkbounds=checkbounds,
207+
linenumbers=linenumbers,parallel=parallel,simplify=simplify,
208+
sparse=sparse,eval_expression=eval_expression,kwargs...)
209+
return f, u0, p
210+
end
211+
212+
function DiffEqBase.NonlinearProblem(sys::NonlinearSystem, args...; kwargs...)
213+
NonlinearProblem{true}(sys, args...; kwargs...)
214+
end
215+
90216
"""
91217
```julia
92218
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,
@@ -101,19 +227,9 @@ Generates an NonlinearProblem from a NonlinearSystem and allows for automaticall
101227
symbolically calculating numerical enhancements.
102228
"""
103229
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,
104-
parammap=DiffEqBase.NullParameters();
105-
jac = false, sparse=false,
106-
checkbounds = false,
107-
linenumbers = true, parallel=SerialForm(),
108-
kwargs...) where iip
109-
dvs = states(sys)
110-
ps = parameters(sys)
111-
112-
f = generate_function(sys;checkbounds=checkbounds,linenumbers=linenumbers,
113-
parallel=parallel,sparse=sparse,expression=Val{false})
114-
u0 = varmap_to_vars(u0map,dvs; defaults=get_default_u0(sys))
115-
p = varmap_to_vars(parammap,ps; defaults=get_default_p(sys))
116-
NonlinearProblem(f,u0,p;kwargs...)
230+
parammap=DiffEqBase.NullParameters();kwargs...) where iip
231+
f, u0, p = process_NonlinearProblem(NonlinearFunction{iip}, sys, u0map, parammap; kwargs...)
232+
NonlinearProblem{iip}(f,u0,p;kwargs...)
117233
end
118234

119235
"""
@@ -132,23 +248,22 @@ numerical enhancements.
132248
"""
133249
struct NonlinearProblemExpr{iip} end
134250

251+
function NonlinearProblemExpr(sys::NonlinearSystem, args...; kwargs...)
252+
NonlinearProblemExpr{true}(sys, args...; kwargs...)
253+
end
254+
135255
function NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,
136-
parammap=DiffEqBase.NullParameters();
137-
jac = false, sparse=false,
138-
checkbounds = false,
139-
linenumbers = false, parallel=SerialForm(),
140-
kwargs...) where iip
141-
dvs = states(sys)
142-
ps = parameters(sys)
256+
parammap=DiffEqBase.NullParameters();
257+
kwargs...) where iip
258+
259+
f, u0, p = process_NonlinearProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap; kwargs...)
260+
linenumbers = get(kwargs, :linenumbers, true)
143261

144-
f = generate_function(sys;checkbounds=checkbounds,linenumbers=linenumbers,
145-
parallel=parallel,sparse=sparse,expression=Val{true})
146-
u0 = varmap_to_vars(u0map,dvs; defaults=get_default_u0(sys))
147-
p = varmap_to_vars(parammap,ps; defaults=get_default_p(sys))
148-
quote
262+
ex = quote
149263
f = $f
150264
u0 = $u0
151265
p = $p
152-
NonlinearProblem(f,u0,p;kwargs...)
266+
NonlinearProblem(f,u0,p;$(kwargs...))
153267
end
268+
!linenumbers ? striplines(ex) : ex
154269
end

0 commit comments

Comments
 (0)