Skip to content

Commit 0f69971

Browse files
Problem and Function expr construction
`probexpr = ODEProblemExpr(sys,u0,tspan,p,jac=true)` gives: ```julia quote f = begin f = ((var"##MTIIPVar#285", var"##MTKArg#281", var"##MTKArg#282", var"##MTKArg#283")->begin @inbounds begin let (xˍt, y, z, x, σ, β, ρ, t) = (var"##MTKArg#281"[1], var"##MTKArg#281"[2], var"##MTKArg#281"[3], var"##MTKArg#281"[4], var"##MTKArg#282"[1], var"##MTKArg#282"[2], var"##MTKArg#282"[3], var"##MTKArg#283") var"##MTIIPVar#285"[1] = σ * (y - x) var"##MTIIPVar#285"[2] = x * (ρ - z) - y var"##MTIIPVar#285"[3] = x * y - β * z var"##MTIIPVar#285"[4] = xˍt end end nothing end) tgrad = nothing jac = ((var"##MTIIPVar#291", var"##MTKArg#287", var"##MTKArg#288", var"##MTKArg#289")->begin @inbounds begin let (xˍt, y, z, x, σ, β, ρ, t) = (var"##MTKArg#287"[1], var"##MTKArg#287"[2], var"##MTKArg#287"[3], var"##MTKArg#287"[4], var"##MTKArg#288"[1], var"##MTKArg#288"[2], var"##MTKArg#288"[3], var"##MTKArg#289") var"##MTIIPVar#291"[1] = 0 var"##MTIIPVar#291"[2] = 0 var"##MTIIPVar#291"[3] = 0 var"##MTIIPVar#291"[4] = 1 var"##MTIIPVar#291"[5] = σ var"##MTIIPVar#291"[6] = -1 var"##MTIIPVar#291"[7] = x var"##MTIIPVar#291"[8] = 0 var"##MTIIPVar#291"[9] = 0 var"##MTIIPVar#291"[10] = -1x var"##MTIIPVar#291"[11] = -1β var"##MTIIPVar#291"[12] = 0 var"##MTIIPVar#291"[13] = -1σ var"##MTIIPVar#291"[14] = -1z + ρ var"##MTIIPVar#291"[15] = y var"##MTIIPVar#291"[16] = 0 end end nothing end) Wfact = nothing Wfact_t = nothing M = UniformScaling{Bool}(true) ODEFunction{iip}(f, jac = jac, tgrad = tgrad, Wfact = Wfact, Wfact_t = Wfact_t, mass_matrix = M, syms = [:xˍt, :y, :z, :x]) end u0 = [2.0, 0.0, 0.0, 1.0] tspan = (0.0, 100.0) p = [28.0, 2.6666666666666665, 10.0] ODEProblem{iip}(f, u0, tspan, p; ) end ```
1 parent 1402694 commit 0f69971

File tree

3 files changed

+121
-1
lines changed

3 files changed

+121
-1
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ include("systems/dependency_graphs.jl")
113113
include("latexify_recipes.jl")
114114
include("build_function.jl")
115115

116-
export ODESystem, ODEFunction
116+
export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr
117117
export SDESystem, SDEFunction
118118
export JumpSystem
119119
export ODEProblem, SDEProblem, NonlinearProblem, OptimizationProblem, SteadyStateProblem

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,79 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
190190
syms = Symbol.(states(sys)))
191191
end
192192

193+
"""
194+
```julia
195+
function DiffEqBase.ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
196+
ps = parameters(sys);
197+
version = nothing, tgrad=false,
198+
jac = false, Wfact = false,
199+
sparse = false,
200+
kwargs...) where {iip}
201+
```
202+
203+
Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
204+
The arguments `dvs` and `ps` are used to set the order of the dependent
205+
variable and parameter vectors, respectively.
206+
"""
207+
struct ODEFunctionExpr{iip} end
208+
209+
function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
210+
ps = parameters(sys), u0 = nothing;
211+
version = nothing, tgrad=false,
212+
jac = false, Wfact = false,
213+
sparse = false,
214+
kwargs...) where {iip}
215+
216+
idx = iip ? 2 : 1
217+
f = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
218+
if tgrad
219+
_tgrad = generate_tgrad(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
220+
else
221+
_tgrad = :nothing
222+
end
223+
224+
if jac
225+
_jac = generate_jacobian(sys, dvs, ps; sparse = sparse, expression=Val{true}, kwargs...)[idx]
226+
else
227+
_jac = :nothing
228+
end
229+
230+
if Wfact
231+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps; expression=Val{true}, kwargs...)
232+
_Wfact = tmp_Wfact[idx]
233+
_Wfact_t = tmp_Wfact_t[idx]
234+
else
235+
_Wfact,_Wfact_t = :nothing,:nothing
236+
end
237+
238+
M = calculate_massmatrix(sys)
239+
240+
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
241+
242+
quote
243+
f = $f
244+
tgrad = $_tgrad
245+
jac = $_jac
246+
Wfact = $_Wfact
247+
Wfact_t = $_Wfact_t
248+
M = $_M
249+
250+
ODEFunction{iip}(f,
251+
jac = jac,
252+
tgrad = tgrad,
253+
Wfact = Wfact,
254+
Wfact_t = Wfact_t,
255+
mass_matrix = M,
256+
syms = $(Symbol.(states(sys))))
257+
end
258+
end
259+
260+
261+
function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
262+
ODEFunctionExpr{true}(sys, args...; kwargs...)
263+
end
264+
265+
193266
function DiffEqBase.ODEProblem(sys::AbstractODESystem, args...; kwargs...)
194267
ODEProblem{true}(sys, args...; kwargs...)
195268
end
@@ -225,6 +298,50 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
225298
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
226299
end
227300

301+
"""
302+
```julia
303+
function DiffEqBase.ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
304+
parammap=DiffEqBase.NullParameters();
305+
version = nothing, tgrad=false,
306+
jac = false, Wfact = false,
307+
checkbounds = false, sparse = false,
308+
linenumbers = true, parallel=SerialForm(),
309+
kwargs...) where iip
310+
```
311+
312+
Generates a Julia expression for constructing an ODEProblem from an
313+
ODESystem and allows for automatically symbolically calculating
314+
numerical enhancements.
315+
"""
316+
struct ODEProblemExpr{iip} end
317+
318+
function ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
319+
parammap=DiffEqBase.NullParameters();
320+
version = nothing, tgrad=false,
321+
jac = false, Wfact = false,
322+
checkbounds = false, sparse = false,
323+
linenumbers = true, parallel=SerialForm(),
324+
kwargs...) where iip
325+
dvs = states(sys)
326+
ps = parameters(sys)
327+
u0 = varmap_to_vars(u0map,dvs)
328+
p = varmap_to_vars(parammap,ps)
329+
f = ODEFunctionExpr{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
330+
linenumbers=linenumbers,parallel=parallel,
331+
sparse=sparse)
332+
quote
333+
f = $f
334+
u0 = $u0
335+
tspan = $tspan
336+
p = $p
337+
ODEProblem{iip}(f,u0,tspan,p;$(kwargs...))
338+
end
339+
end
340+
341+
function ODEProblemExpr(sys::AbstractODESystem, args...; kwargs...)
342+
ODEProblemExpr{true}(sys, args...; kwargs...)
343+
end
344+
228345

229346
### Enables Steady State Problems ###
230347
function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem, args...; kwargs...)

test/lowering_solving.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ p = [σ => 28.0,
3131

3232
tspan = (0.0,100.0)
3333
prob = ODEProblem(sys,u0,tspan,p,jac=true)
34+
probexpr = ODEProblemExpr(sys,u0,tspan,p,jac=true)
3435
sol = solve(prob,Tsit5())
36+
solexpr = solve(eval(prob),Tsit5())
37+
@test all(x->x==0,Array(sol - solexpr))
3538
#using Plots; plot(sol,vars=(:x,:y))
3639

3740
@parameters t σ ρ β

0 commit comments

Comments
 (0)