Skip to content

Commit f7a3d78

Browse files
committed
Add DAEProblemExpr and DAEFunctionExpr
1 parent f9862b7 commit f7a3d78

File tree

3 files changed

+101
-5
lines changed

3 files changed

+101
-5
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ include("structural_transformation/StructuralTransformations.jl")
137137
@reexport using .StructuralTransformations
138138

139139
export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr
140+
export DAEFunctionExpr, DAEProblemExpr
140141
export SDESystem, SDEFunction, SDEFunctionExpr, SDESystemExpr
141142
export SystemStructure
142143
export JumpSystem

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ end
271271

272272
"""
273273
```julia
274-
function DiffEqBase.ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
274+
function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
275275
ps = parameters(sys);
276276
version = nothing, tgrad=false,
277277
jac = false,
@@ -384,6 +384,50 @@ function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
384384
ODEFunctionExpr{true}(sys, args...; kwargs...)
385385
end
386386

387+
"""
388+
```julia
389+
function DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
390+
ps = parameters(sys);
391+
version = nothing, tgrad=false,
392+
jac = false,
393+
sparse = false,
394+
kwargs...) where {iip}
395+
```
396+
397+
Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
398+
The arguments `dvs` and `ps` are used to set the order of the dependent
399+
variable and parameter vectors, respectively.
400+
"""
401+
struct DAEFunctionExpr{iip} end
402+
403+
struct DAEFunctionClosure{O, I} <: Function
404+
f_oop::O
405+
f_iip::I
406+
end
407+
(f::DAEFunctionClosure)(du, u, p, t) = f.f_oop(du, u, p, t)
408+
(f::DAEFunctionClosure)(out, du, u, p, t) = f.f_iip(out, du, u, p, t)
409+
410+
function DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
411+
ps = parameters(sys), u0 = nothing;
412+
version = nothing, tgrad=false,
413+
jac = false,
414+
linenumbers = false,
415+
sparse = false, simplify=false,
416+
kwargs...) where {iip}
417+
f_oop, f_iip = generate_function(sys, dvs, ps; expression=Val{true}, implicit_dae = true, kwargs...)
418+
fsym = gensym(:f)
419+
_f = :($fsym = $DAEFunctionClosure($f_oop, $f_iip))
420+
ex = quote
421+
$_f
422+
ODEFunction{$iip}($fsym,)
423+
end
424+
!linenumbers ? striplines(ex) : ex
425+
end
426+
427+
function DAEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
428+
DAEFunctionExpr{true}(sys, args...; kwargs...)
429+
end
430+
387431
for P in [:ODEProblem, :DAEProblem]
388432
@eval function DiffEqBase.$P(sys::AbstractODESystem, args...; kwargs...)
389433
$P{true}(sys, args...; kwargs...)
@@ -440,7 +484,7 @@ end
440484

441485
"""
442486
```julia
443-
function DiffEqBase.ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
487+
function ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
444488
parammap=DiffEqBase.NullParameters();
445489
version = nothing, tgrad=false,
446490
jac = false,
@@ -478,6 +522,53 @@ function ODEProblemExpr(sys::AbstractODESystem, args...; kwargs...)
478522
ODEProblemExpr{true}(sys, args...; kwargs...)
479523
end
480524

525+
"""
526+
```julia
527+
function DAEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
528+
parammap=DiffEqBase.NullParameters();
529+
version = nothing, tgrad=false,
530+
jac = false,
531+
checkbounds = false, sparse = false,
532+
linenumbers = true, parallel=SerialForm(),
533+
skipzeros=true, fillzeros=true,
534+
simplify=false,
535+
kwargs...) where iip
536+
```
537+
538+
Generates a Julia expression for constructing an ODEProblem from an
539+
ODESystem and allows for automatically symbolically calculating
540+
numerical enhancements.
541+
"""
542+
struct DAEProblemExpr{iip} end
543+
544+
function DAEProblemExpr{iip}(sys::AbstractODESystem,du0map,u0map,tspan,
545+
parammap=DiffEqBase.NullParameters();
546+
kwargs...) where iip
547+
f, du0, u0, p = process_DEProblem(
548+
DAEFunctionExpr{iip}, sys, u0map, parammap;
549+
implicit_dae=true, du0map=du0map, kwargs...
550+
)
551+
linenumbers = get(kwargs, :linenumbers, true)
552+
diffvars = collect_differential_variables(sys)
553+
sts = states(sys)
554+
differential_vars = map(Base.Fix2(in, diffvars), sts)
555+
556+
ex = quote
557+
f = $f
558+
u0 = $u0
559+
du0 = $du0
560+
tspan = $tspan
561+
p = $p
562+
differential_vars = $differential_vars
563+
DAEProblem{$iip}(f,du0,u0,tspan,p;differential_vars=differential_vars,$(kwargs...))
564+
end
565+
!linenumbers ? striplines(ex) : ex
566+
end
567+
568+
function DAEProblemExpr(sys::AbstractODESystem, args...; kwargs...)
569+
DAEProblemExpr{true}(sys, args...; kwargs...)
570+
end
571+
481572

482573
### Enables Steady State Problems ###
483574
function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem, args...; kwargs...)

test/odesystem.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,13 @@ du0 = [
238238
D(y₃) => 0.0
239239
]
240240
prob4 = DAEProblem(sys, du0, u0, tspan, p2)
241-
@test prob4.differential_vars == [true, true, false]
242-
sol = solve(prob4, IDA())
243-
@test all(x->(sum(x), 1.0, atol=1e-12), sol.u)
241+
prob5 = eval(DAEProblemExpr(sys, du0, u0, tspan, p2))
242+
for prob in [prob4, prob5]
243+
local sol
244+
@test prob.differential_vars == [true, true, false]
245+
sol = solve(prob, IDA())
246+
@test all(x->(sum(x), 1.0, atol=1e-12), sol.u)
247+
end
244248

245249
@test ModelingToolkit.construct_state(SArray{Tuple{3,3}}(rand(3,3)), [1,2]) == SVector{2}([1, 2])
246250

0 commit comments

Comments
 (0)