Skip to content

Commit b2ce1b3

Browse files
Merge pull request #897 from SciML/myb/cleanup
Add DAEProblem overload for `ODESystem`s
2 parents 11a6836 + d67b8ab commit b2ce1b3

File tree

7 files changed

+252
-23
lines changed

7 files changed

+252
-23
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
8181
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
8282
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
8383
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
84+
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
8485
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8586

8687
[targets]
87-
test = ["BenchmarkTools", "ForwardDiff", "GalacticOptim", "NonlinearSolve", "OrdinaryDiffEq", "Optim", "Random", "SteadyStateDiffEq", "Test", "StochasticDiffEq"]
88+
test = ["BenchmarkTools", "ForwardDiff", "GalacticOptim", "NonlinearSolve", "OrdinaryDiffEq", "Optim", "Random", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ include("structural_transformation/StructuralTransformations.jl")
137137
@reexport using .StructuralTransformations
138138

139139
export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr
140+
export DAEFunctionExpr, DAEProblemExpr
140141
export SDESystem, SDEFunction, SDEFunctionExpr, SDESystemExpr
141142
export SystemStructure
142143
export JumpSystem

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 212 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,23 +61,34 @@ function check_derivative_variables(eq, expr=eq.rhs)
6161
foreach(Base.Fix1(check_derivative_variables, eq), arguments(expr))
6262
end
6363

64-
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
64+
function generate_function(
65+
sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
66+
implicit_dae=false,
67+
ddvs=implicit_dae ? map(Differential(independent_variable(sys)), dvs) : nothing,
68+
kwargs...
69+
)
6570
# optimization
6671
#obsvars = map(eq->eq.lhs, observed(sys))
6772
#fulldvs = [dvs; obsvars]
6873

6974
eqs = equations(sys)
7075
foreach(check_derivative_variables, eqs)
7176
# substitute x(t) by just x
72-
rhss = [deq.rhs for deq in eqs]
77+
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
78+
[eq.rhs for eq in eqs]
7379
#obss = [makesym(value(eq.lhs)) ~ substitute(eq.rhs, sub) for eq ∈ observed(sys)]
7480
#rhss = Let(obss, rhss)
7581

7682
# TODO: add an optional check on the ordering of observed equations
77-
return build_function(rhss,
78-
map(x->time_varying_as_func(value(x), sys), dvs),
79-
map(x->time_varying_as_func(value(x), sys), ps),
80-
get_iv(sys); kwargs...)
83+
u = map(x->time_varying_as_func(value(x), sys), dvs)
84+
p = map(x->time_varying_as_func(value(x), sys), ps)
85+
t = get_iv(sys)
86+
87+
if implicit_dae
88+
build_function(rhss, ddvs, u, p, t; kwargs...)
89+
else
90+
build_function(rhss, u, p, t; kwargs...)
91+
end
8192
end
8293

8394
function time_varying_as_func(x, sys)
@@ -120,8 +131,10 @@ function isautonomous(sys::AbstractODESystem)
120131
all(iszero,tgrad)
121132
end
122133

123-
function DiffEqBase.ODEFunction(sys::AbstractODESystem, args...; kwargs...)
124-
ODEFunction{true}(sys, args...; kwargs...)
134+
for F in [:ODEFunction, :DAEFunction]
135+
@eval function DiffEqBase.$F(sys::AbstractODESystem, args...; kwargs...)
136+
$F{true}(sys, args...; kwargs...)
137+
end
125138
end
126139

127140
"""
@@ -201,7 +214,64 @@ end
201214

202215
"""
203216
```julia
204-
function DiffEqBase.ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
217+
function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
218+
ps = parameters(sys);
219+
version = nothing, tgrad=false,
220+
jac = false,
221+
sparse = false,
222+
kwargs...) where {iip}
223+
```
224+
225+
Create an `DAEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and
226+
`ps` are used to set the order of the dependent variable and parameter vectors,
227+
respectively.
228+
"""
229+
function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
230+
ps = parameters(sys), u0 = nothing;
231+
ddvs=map(diff2term Differential(independent_variable(sys)), dvs),
232+
version = nothing,
233+
#=
234+
tgrad=false,
235+
jac = false,
236+
sparse = false,
237+
=#
238+
simplify=false,
239+
eval_expression = true,
240+
eval_module = @__MODULE__,
241+
kwargs...) where {iip}
242+
243+
f_gen = generate_function(sys, dvs, ps; implicit_dae = true, expression=Val{eval_expression}, expression_module=eval_module, kwargs...)
244+
f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen
245+
f(du,u,p,t) = f_oop(du,u,p,t)
246+
f(out,du,u,p,t) = f_iip(out,du,u,p,t)
247+
248+
# TODO: Jacobian sparsity / sparse Jacobian / dense Jacobian
249+
250+
#=
251+
observedfun = let sys = sys, dict = Dict()
252+
# TODO: We don't have enought information to reconstruct arbitrary state
253+
# in general from `(u, p, t)`, e.g. `a ~ D(x)`.
254+
function generated_observed(obsvar, u, p, t)
255+
obs = get!(dict, value(obsvar)) do
256+
build_explicit_observed_function(sys, obsvar)
257+
end
258+
obs(u, p, t)
259+
end
260+
end
261+
=#
262+
263+
DAEFunction{iip}(
264+
f,
265+
syms = Symbol.(dvs),
266+
# missing fields in `DAEFunction`
267+
#indepsym = Symbol(independent_variable(sys)),
268+
#observed = observedfun,
269+
)
270+
end
271+
272+
"""
273+
```julia
274+
function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
205275
ps = parameters(sys);
206276
version = nothing, tgrad=false,
207277
jac = false,
@@ -277,6 +347,7 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
277347
end
278348

279349
function process_DEProblem(constructor, sys::AbstractODESystem,u0map,parammap;
350+
implicit_dae = false, du0map = nothing,
280351
version = nothing, tgrad=false,
281352
jac = false,
282353
checkbounds = false, sparse = false,
@@ -287,27 +358,80 @@ function process_DEProblem(constructor, sys::AbstractODESystem,u0map,parammap;
287358
dvs = states(sys)
288359
ps = parameters(sys)
289360
defs = defaults(sys)
361+
iv = independent_variable(sys)
290362

291363
u0 = varmap_to_vars(u0map,dvs; defaults=defs)
364+
if implicit_dae && du0map !== nothing
365+
ddvs = map(Differential(iv), dvs)
366+
du0 = varmap_to_vars(du0map, ddvs; defaults=defaults, toterm=identity)
367+
else
368+
du0 = nothing
369+
ddvs = nothing
370+
end
292371
p = varmap_to_vars(parammap,ps; defaults=defs)
293372

294373
if u0 !== nothing
295374
length(dvs) == length(u0) || throw(ArgumentError("States ($(length(dvs))) and initial conditions ($(length(u0))) are of different lengths."))
296375
end
297376

298-
f = constructor(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,checkbounds=checkbounds,
377+
f = constructor(sys,dvs,ps,u0;ddvs=ddvs,tgrad=tgrad,jac=jac,checkbounds=checkbounds,
299378
linenumbers=linenumbers,parallel=parallel,simplify=simplify,
300379
sparse=sparse,eval_expression=eval_expression,kwargs...)
301-
return f, u0, p
380+
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
302381
end
303382

304383
function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
305384
ODEFunctionExpr{true}(sys, args...; kwargs...)
306385
end
307386

387+
"""
388+
```julia
389+
function DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
390+
ps = parameters(sys);
391+
version = nothing, tgrad=false,
392+
jac = false,
393+
sparse = false,
394+
kwargs...) where {iip}
395+
```
396+
397+
Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
398+
The arguments `dvs` and `ps` are used to set the order of the dependent
399+
variable and parameter vectors, respectively.
400+
"""
401+
struct DAEFunctionExpr{iip} end
308402

309-
function DiffEqBase.ODEProblem(sys::AbstractODESystem, args...; kwargs...)
310-
ODEProblem{true}(sys, args...; kwargs...)
403+
struct DAEFunctionClosure{O, I} <: Function
404+
f_oop::O
405+
f_iip::I
406+
end
407+
(f::DAEFunctionClosure)(du, u, p, t) = f.f_oop(du, u, p, t)
408+
(f::DAEFunctionClosure)(out, du, u, p, t) = f.f_iip(out, du, u, p, t)
409+
410+
function DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
411+
ps = parameters(sys), u0 = nothing;
412+
version = nothing, tgrad=false,
413+
jac = false,
414+
linenumbers = false,
415+
sparse = false, simplify=false,
416+
kwargs...) where {iip}
417+
f_oop, f_iip = generate_function(sys, dvs, ps; expression=Val{true}, implicit_dae = true, kwargs...)
418+
fsym = gensym(:f)
419+
_f = :($fsym = $DAEFunctionClosure($f_oop, $f_iip))
420+
ex = quote
421+
$_f
422+
ODEFunction{$iip}($fsym,)
423+
end
424+
!linenumbers ? striplines(ex) : ex
425+
end
426+
427+
function DAEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
428+
DAEFunctionExpr{true}(sys, args...; kwargs...)
429+
end
430+
431+
for P in [:ODEProblem, :DAEProblem]
432+
@eval function DiffEqBase.$P(sys::AbstractODESystem, args...; kwargs...)
433+
$P{true}(sys, args...; kwargs...)
434+
end
311435
end
312436

313437
"""
@@ -333,7 +457,34 @@ end
333457

334458
"""
335459
```julia
336-
function DiffEqBase.ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
460+
function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
461+
parammap=DiffEqBase.NullParameters();
462+
version = nothing, tgrad=false,
463+
jac = false,
464+
checkbounds = false, sparse = false,
465+
simplify=false,
466+
linenumbers = true, parallel=SerialForm(),
467+
kwargs...) where iip
468+
```
469+
470+
Generates an DAEProblem from an ODESystem and allows for automatically
471+
symbolically calculating numerical enhancements.
472+
"""
473+
function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem,du0map,u0map,tspan,
474+
parammap=DiffEqBase.NullParameters();kwargs...) where iip
475+
f, du0, u0, p = process_DEProblem(
476+
DAEFunction{iip}, sys, u0map, parammap;
477+
implicit_dae=true, du0map=du0map, kwargs...
478+
)
479+
diffvars = collect_differential_variables(sys)
480+
sts = states(sys)
481+
differential_vars = map(Base.Fix2(in, diffvars), sts)
482+
DAEProblem{iip}(f,du0,u0,tspan,p;differential_vars=differential_vars,kwargs...)
483+
end
484+
485+
"""
486+
```julia
487+
function ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
337488
parammap=DiffEqBase.NullParameters();
338489
version = nothing, tgrad=false,
339490
jac = false,
@@ -371,6 +522,53 @@ function ODEProblemExpr(sys::AbstractODESystem, args...; kwargs...)
371522
ODEProblemExpr{true}(sys, args...; kwargs...)
372523
end
373524

525+
"""
526+
```julia
527+
function DAEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
528+
parammap=DiffEqBase.NullParameters();
529+
version = nothing, tgrad=false,
530+
jac = false,
531+
checkbounds = false, sparse = false,
532+
linenumbers = true, parallel=SerialForm(),
533+
skipzeros=true, fillzeros=true,
534+
simplify=false,
535+
kwargs...) where iip
536+
```
537+
538+
Generates a Julia expression for constructing an ODEProblem from an
539+
ODESystem and allows for automatically symbolically calculating
540+
numerical enhancements.
541+
"""
542+
struct DAEProblemExpr{iip} end
543+
544+
function DAEProblemExpr{iip}(sys::AbstractODESystem,du0map,u0map,tspan,
545+
parammap=DiffEqBase.NullParameters();
546+
kwargs...) where iip
547+
f, du0, u0, p = process_DEProblem(
548+
DAEFunctionExpr{iip}, sys, u0map, parammap;
549+
implicit_dae=true, du0map=du0map, kwargs...
550+
)
551+
linenumbers = get(kwargs, :linenumbers, true)
552+
diffvars = collect_differential_variables(sys)
553+
sts = states(sys)
554+
differential_vars = map(Base.Fix2(in, diffvars), sts)
555+
556+
ex = quote
557+
f = $f
558+
u0 = $u0
559+
du0 = $du0
560+
tspan = $tspan
561+
p = $p
562+
differential_vars = $differential_vars
563+
DAEProblem{$iip}(f,du0,u0,tspan,p;differential_vars=differential_vars,$(kwargs...))
564+
end
565+
!linenumbers ? striplines(ex) : ex
566+
end
567+
568+
function DAEProblemExpr(sys::AbstractODESystem, args...; kwargs...)
569+
DAEProblemExpr{true}(sys, args...; kwargs...)
570+
end
571+
374572

375573
### Enables Steady State Problems ###
376574
function DiffEqBase.SteadyStateProblem(sys::AbstractODESystem, args...; kwargs...)

src/systems/diffeqs/odesystem.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,18 @@ function _eq_unordered(a, b)
280280
end
281281
return true
282282
end
283+
284+
function collect_differential_variables(sys::ODESystem)
285+
eqs = equations(sys)
286+
vars = Set()
287+
diffvars = Set()
288+
for eq in eqs
289+
vars!(vars, eq)
290+
for v in vars
291+
isdifferential(v) || continue
292+
push!(diffvars, arguments(v)[1])
293+
end
294+
empty!(vars)
295+
end
296+
return diffvars
297+
end

src/variables.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Takes a list of pairs of `variables=>values` and an ordered list of variables
1010
and creates the array of values in the correct order with default values when
1111
applicable.
1212
"""
13-
function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true)
13+
function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true, toterm=Symbolics.diff2term)
1414
# Edge cases where one of the arguments is effectively empty.
1515
is_incomplete_initialization = varmap isa DiffEqBase.NullParameters || varmap === nothing
1616
if is_incomplete_initialization || isempty(varmap)
@@ -31,7 +31,7 @@ function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true)
3131
if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs
3232
varmap = todict(varmap)
3333
rules = Dict(varmap)
34-
vals = _varmap_to_vars(varmap, varlist; defaults=defaults, check=check)
34+
vals = _varmap_to_vars(varmap, varlist; defaults=defaults, check=check, toterm=toterm)
3535
else # plain array-like initialization
3636
vals = varmap
3737
end
@@ -45,9 +45,9 @@ function varmap_to_vars(varmap, varlist; defaults=Dict(), check=true)
4545
end
4646
end
4747

48-
function _varmap_to_vars(varmap::Dict, varlist; defaults=Dict(), check=false)
48+
function _varmap_to_vars(varmap::Dict, varlist; defaults=Dict(), check=false, toterm=Symbolics.diff2term)
4949
varmap = merge(defaults, varmap) # prefers the `varmap`
50-
varmap = Dict(Symbolics.diff2term(value(k))=>value(varmap[k]) for k in keys(varmap))
50+
varmap = Dict(toterm(value(k))=>value(varmap[k]) for k in keys(varmap))
5151
# resolve symbolic parameter expressions
5252
for (p, v) in pairs(varmap)
5353
varmap[p] = fixpoint_sub(v, varmap)

0 commit comments

Comments
 (0)