Skip to content

Commit e8a5739

Browse files
add other system expr outputs
1 parent 8b65d4e commit e8a5739

File tree

4 files changed

+260
-7
lines changed

4 files changed

+260
-7
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
204204
The arguments `dvs` and `ps` are used to set the order of the dependent
205205
variable and parameter vectors, respectively.
206206
"""
207-
struct ODEFunctionExpr{iip} end
207+
struct ODEFunctionExpr end
208208

209209
function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
210210
ps = parameters(sys), u0 = nothing;
@@ -314,7 +314,7 @@ Generates a Julia expression for constructing an ODEProblem from an
314314
ODESystem and allows for automatically symbolically calculating
315315
numerical enhancements.
316316
"""
317-
struct ODEProblemExpr{iip} end
317+
struct ODEProblemExpr end
318318

319319
function ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
320320
parammap=DiffEqBase.NullParameters();
@@ -335,7 +335,7 @@ function ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
335335
u0 = $u0
336336
tspan = $tspan
337337
p = $p
338-
ODEProblem{iip}(f,u0,tspan,p;$(kwargs...))
338+
ODEProblem(f,u0,tspan,p;$(kwargs...))
339339
end
340340
!linenumbers ? striplines(ex) : ex
341341
end
@@ -363,7 +363,7 @@ function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem,u0map,tspan,
363363
Generates an SteadyStateProblem from an ODESystem and allows for automatically
364364
symbolically calculating numerical enhancements.
365365
"""
366-
function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem,u0map,
366+
function DiffEqBase.SteadyStateProblem{iip}(sys::AbstractODESystem,u0map,
367367
parammap=DiffEqBase.NullParameters();
368368
version = nothing, tgrad=false,
369369
jac = false, Wfact = false,
@@ -379,3 +379,46 @@ function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem,u0map,
379379
sparse=sparse)
380380
SteadyStateProblem(f,u0,p;kwargs...)
381381
end
382+
383+
struct SteadyStateProblemExpr end
384+
385+
"""
386+
```julia
387+
function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem,u0map,tspan,
388+
parammap=DiffEqBase.NullParameters();
389+
version = nothing, tgrad=false,
390+
jac = false, Wfact = false,
391+
checkbounds = false, sparse = false,
392+
linenumbers = true, parallel=SerialForm(),
393+
kwargs...) where iip
394+
```
395+
Generates an SteadyStateProblem from an ODESystem and allows for automatically
396+
symbolically calculating numerical enhancements.
397+
"""
398+
function DiffEqBase.SteadyStateProblem{iip}(sys::AbstractODESystem,u0map,
399+
parammap=DiffEqBase.NullParameters();
400+
version = nothing, tgrad=false,
401+
jac = false, Wfact = false,
402+
checkbounds = false, sparse = false,
403+
linenumbers = true, parallel=SerialForm(),
404+
kwargs...) where iip
405+
dvs = states(sys)
406+
ps = parameters(sys)
407+
u0 = varmap_to_vars(u0map,dvs)
408+
p = varmap_to_vars(parammap,ps)
409+
f = ODEFunctionExpr(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
410+
linenumbers=linenumbers,parallel=parallel,
411+
sparse=sparse)
412+
ex = quote
413+
f = $f
414+
u0 = $u0
415+
tspan = $tspan
416+
p = $p
417+
SteadyStateProblem(f,u0,tspan,p;$(kwargs...))
418+
end
419+
!linenumbers ? striplines(ex) : ex
420+
end
421+
422+
function DiffEqBase.SteadyStateProblemExpr(sys::AbstractODESystem, args...; kwargs...)
423+
SteadyStateProblemExpr{true}(sys, args...; kwargs...)
424+
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,83 @@ function DiffEqBase.SDEFunction(sys::SDESystem, args...; kwargs...)
154154
SDEFunction{true}(sys, args...; kwargs...)
155155
end
156156

157+
"""
158+
```julia
159+
function DiffEqBase.SDEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
160+
ps = parameters(sys);
161+
version = nothing, tgrad=false,
162+
jac = false, Wfact = false,
163+
sparse = false,
164+
kwargs...) where {iip}
165+
```
166+
167+
Create a Julia expression for an `SDEFunction` from the [`SDESystem`](@ref).
168+
The arguments `dvs` and `ps` are used to set the order of the dependent
169+
variable and parameter vectors, respectively.
170+
"""
171+
struct SDEFunctionExpr end
172+
173+
function SDEFunctionExpr{iip}(sys::SDESystem, dvs = states(sys),
174+
ps = parameters(sys), u0 = nothing;
175+
version = nothing, tgrad=false,
176+
jac = false, Wfact = false,
177+
sparse = false,linenumbers = false,
178+
kwargs...) where {iip}
179+
180+
idx = iip ? 2 : 1
181+
f = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
182+
g = generate_diffusion_function(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
183+
if tgrad
184+
_tgrad = generate_tgrad(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
185+
else
186+
_tgrad = :nothing
187+
end
188+
189+
if jac
190+
_jac = generate_jacobian(sys, dvs, ps; sparse = sparse, expression=Val{true}, kwargs...)[idx]
191+
else
192+
_jac = :nothing
193+
end
194+
195+
if Wfact
196+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps; expression=Val{true}, kwargs...)
197+
_Wfact = tmp_Wfact[idx]
198+
_Wfact_t = tmp_Wfact_t[idx]
199+
else
200+
_Wfact,_Wfact_t = :nothing,:nothing
201+
end
202+
203+
M = calculate_massmatrix(sys)
204+
205+
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
206+
207+
ex = quote
208+
f = $f
209+
g = $g
210+
tgrad = $_tgrad
211+
jac = $_jac
212+
Wfact = $_Wfact
213+
Wfact_t = $_Wfact_t
214+
M = $_M
215+
216+
SDEFunction{iip}(f,g,
217+
jac = jac,
218+
tgrad = tgrad,
219+
Wfact = Wfact,
220+
Wfact_t = Wfact_t,
221+
mass_matrix = M,
222+
syms = $(Symbol.(states(sys))))
223+
end
224+
!linenumbers ? striplines(ex) : ex
225+
end
226+
227+
228+
function SDEFunctionExpr(sys::SDESystem, args...; kwargs...)
229+
SDEFunctionExpr{true}(sys, args...; kwargs...)
230+
end
231+
157232
function rename(sys::SDESystem,name)
158-
ODESystem(sys.eqs, sys.noiseeqs, sys.iv, sys.states, sys.ps, sys.tgrad, sys.jac, sys.Wfact, sys.Wfact_t, name, sys.systems)
233+
SDESystem(sys.eqs, sys.noiseeqs, sys.iv, sys.states, sys.ps, sys.tgrad, sys.jac, sys.Wfact, sys.Wfact_t, name, sys.systems)
159234
end
160235

161236
"""
@@ -203,3 +278,58 @@ end
203278
function DiffEqBase.SDEProblem(sys::SDESystem, args...; kwargs...)
204279
SDEProblem{true}(sys, args...; kwargs...)
205280
end
281+
282+
"""
283+
```julia
284+
function DiffEqBase.SDEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
285+
parammap=DiffEqBase.NullParameters();
286+
version = nothing, tgrad=false,
287+
jac = false, Wfact = false,
288+
checkbounds = false, sparse = false,
289+
linenumbers = true, parallel=SerialForm(),
290+
kwargs...) where iip
291+
```
292+
293+
Generates a Julia expression for constructing an ODEProblem from an
294+
ODESystem and allows for automatically symbolically calculating
295+
numerical enhancements.
296+
"""
297+
struct SDEProblemExpr end
298+
299+
function SDEProblemExpr{iip}(sys::SDESystem,u0map,tspan,
300+
parammap=DiffEqBase.NullParameters();
301+
version = nothing, tgrad=false,
302+
jac = false, Wfact = false,
303+
checkbounds = false, sparse = false,
304+
linenumbers = false, parallel=SerialForm(),
305+
kwargs...) where iip
306+
dvs = states(sys)
307+
ps = parameters(sys)
308+
u0 = varmap_to_vars(u0map,dvs)
309+
p = varmap_to_vars(parammap,ps)
310+
f = SDEFunctionExpr{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,
311+
Wfact=Wfact,checkbounds=checkbounds,
312+
linenumbers=linenumbers,parallel=parallel,
313+
sparse=sparse)
314+
if typeof(sys.noiseeqs) <: AbstractVector
315+
noise_rate_prototype = nothing
316+
elseif sparsenoise
317+
I,J,V = findnz(SparseArrays.sparse(sys.noiseeqs))
318+
noise_rate_prototype = SparseArrays.sparse(I,J,zero(eltype(u0)))
319+
else
320+
noise_rate_prototype = zeros(eltype(u0),size(sys.noiseeqs))
321+
end
322+
ex = quote
323+
f = $f
324+
u0 = $u0
325+
tspan = $tspan
326+
p = $p
327+
noise_rate_prototype = $noise_rate_prototype
328+
SDEProblem(f,f.g,u0,tspan,p;noise_rate_prototype=noise_rate_prototype,$(kwargs...))
329+
end
330+
!linenumbers ? striplines(ex) : ex
331+
end
332+
333+
function SDEProblemExpr(sys::SDESystem, args...; kwargs...)
334+
SDEProblemExpr{true}(sys, args...; kwargs...)
335+
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,tspan,
7878
Generates an NonlinearProblem from a NonlinearSystem and allows for automatically
7979
symbolically calculating numerical enhancements.
8080
"""
81-
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,tspan,
81+
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,
8282
parammap=DiffEqBase.NullParameters();
8383
jac = false, sparse=false,
8484
checkbounds = false,
@@ -91,5 +91,42 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,tspan,
9191
parallel=parallel,sparse=sparse,expression=Val{false})
9292
u0 = varmap_to_vars(u0map,dvs)
9393
p = varmap_to_vars(parammap,ps)
94-
NonlinearProblem(f,u0,tspan,p;kwargs...)
94+
NonlinearProblem(f,u0,p;kwargs...)
95+
end
96+
97+
struct NonlinearProblemExpr end
98+
99+
"""
100+
```julia
101+
function DiffEqBase.NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,tspan,
102+
parammap=DiffEqBase.NullParameters();
103+
jac = false, sparse=false,
104+
checkbounds = false,
105+
linenumbers = true, parallel=SerialForm(),
106+
kwargs...) where iip
107+
```
108+
109+
Generates a Julia expression for a NonlinearProblem from a
110+
NonlinearSystem and allows for automatically symbolically calculating
111+
numerical enhancements.
112+
"""
113+
function DiffEqBase.NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,tspan,
114+
parammap=DiffEqBase.NullParameters();
115+
jac = false, sparse=false,
116+
checkbounds = false,
117+
linenumbers = false, parallel=SerialForm(),
118+
kwargs...) where iip
119+
dvs = states(sys)
120+
ps = parameters(sys)
121+
122+
f = generate_function(sys;checkbounds=checkbounds,linenumbers=linenumbers,
123+
parallel=parallel,sparse=sparse,expression=Val{true})
124+
u0 = varmap_to_vars(u0map,dvs)
125+
p = varmap_to_vars(parammap,ps)
126+
quote
127+
f = $f
128+
u0 = $u0
129+
p = $p
130+
NonlinearProblem(f,u0,p;kwargs...)
131+
end
95132
end

src/systems/optimization/optimizationsystem.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,46 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
105105
ub = varmap_to_vars(ub,dvs)
106106
OptimizationProblem(f,p;u0=u0,lb=lb,ub=ub,kwargs...)
107107
end
108+
109+
struct OptimizationProblemExpr end
110+
111+
"""
112+
```julia
113+
function DiffEqBase.OptimizationProblemExpr{iip}(sys::OptimizationSystem,
114+
parammap=DiffEqBase.NullParameters();
115+
u0=nothing, lb=nothing, ub=nothing,
116+
hes = false, sparse = false,
117+
checkbounds = false,
118+
linenumbers = true, parallel=SerialForm(),
119+
kwargs...) where iip
120+
```
121+
122+
Generates a Julia expression for an OptimizationProblem from an
123+
OptimizationSystem and allows for automatically symbolically
124+
calculating numerical enhancements.
125+
"""
126+
function DiffEqBase.OptimizationProblemExpr{iip}(sys::OptimizationSystem,
127+
parammap=DiffEqBase.NullParameters();
128+
u0=nothing, lb=nothing, ub=nothing,
129+
hes = false, sparse = false,
130+
checkbounds = false,
131+
linenumbers = false, parallel=SerialForm(),
132+
kwargs...) where iip
133+
dvs = states(sys)
134+
ps = parameters(sys)
135+
136+
f = generate_function(sys,checkbounds=checkbounds,linenumbers=linenumbers,
137+
parallel=parallel,expression=Val{true})
138+
u0 = varmap_to_vars(u0,dvs)
139+
p = varmap_to_vars(parammap,ps)
140+
lb = varmap_to_vars(lb,dvs)
141+
ub = varmap_to_vars(ub,dvs)
142+
quote
143+
f = $f
144+
p = $p
145+
u0 = $u0
146+
lb = $lb
147+
ub = $ub
148+
OptimizationProblem(f,p;u0=u0,lb=lb,ub=ub,kwargs...)
149+
end
150+
end

0 commit comments

Comments
 (0)