@@ -61,23 +61,34 @@ function check_derivative_variables(eq, expr=eq.rhs)
61
61
foreach (Base. Fix1 (check_derivative_variables, eq), arguments (expr))
62
62
end
63
63
64
- function generate_function (sys:: AbstractODESystem , dvs = states (sys), ps = parameters (sys); kwargs... )
64
+ function generate_function (
65
+ sys:: AbstractODESystem , dvs = states (sys), ps = parameters (sys);
66
+ implicit_dae= false ,
67
+ ddvs= implicit_dae ? map (Differential (independent_variable (sys)), dvs) : nothing ,
68
+ kwargs...
69
+ )
65
70
# optimization
66
71
# obsvars = map(eq->eq.lhs, observed(sys))
67
72
# fulldvs = [dvs; obsvars]
68
73
69
74
eqs = equations (sys)
70
75
foreach (check_derivative_variables, eqs)
71
76
# substitute x(t) by just x
72
- rhss = [deq. rhs for deq in eqs]
77
+ rhss = implicit_dae ? [_iszero (eq. lhs) ? eq. rhs : eq. rhs - eq. lhs for eq in eqs] :
78
+ [eq. rhs for eq in eqs]
73
79
# obss = [makesym(value(eq.lhs)) ~ substitute(eq.rhs, sub) for eq ∈ observed(sys)]
74
80
# rhss = Let(obss, rhss)
75
81
76
82
# TODO : add an optional check on the ordering of observed equations
77
- return build_function (rhss,
78
- map (x-> time_varying_as_func (value (x), sys), dvs),
79
- map (x-> time_varying_as_func (value (x), sys), ps),
80
- get_iv (sys); kwargs... )
83
+ u = map (x-> time_varying_as_func (value (x), sys), dvs)
84
+ p = map (x-> time_varying_as_func (value (x), sys), ps)
85
+ t = get_iv (sys)
86
+
87
+ if implicit_dae
88
+ build_function (rhss, ddvs, u, p, t; kwargs... )
89
+ else
90
+ build_function (rhss, u, p, t; kwargs... )
91
+ end
81
92
end
82
93
83
94
function time_varying_as_func (x, sys)
@@ -120,8 +131,10 @@ function isautonomous(sys::AbstractODESystem)
120
131
all (iszero,tgrad)
121
132
end
122
133
123
- function DiffEqBase. ODEFunction (sys:: AbstractODESystem , args... ; kwargs... )
124
- ODEFunction {true} (sys, args... ; kwargs... )
134
+ for F in [:ODEFunction , :DAEFunction ]
135
+ @eval function DiffEqBase. $F (sys:: AbstractODESystem , args... ; kwargs... )
136
+ $ F {true} (sys, args... ; kwargs... )
137
+ end
125
138
end
126
139
127
140
"""
201
214
202
215
"""
203
216
```julia
204
- function DiffEqBase.ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
217
+ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
218
+ ps = parameters(sys);
219
+ version = nothing, tgrad=false,
220
+ jac = false,
221
+ sparse = false,
222
+ kwargs...) where {iip}
223
+ ```
224
+
225
+ Create an `DAEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and
226
+ `ps` are used to set the order of the dependent variable and parameter vectors,
227
+ respectively.
228
+ """
229
+ function DiffEqBase. DAEFunction {iip} (sys:: AbstractODESystem , dvs = states (sys),
230
+ ps = parameters (sys), u0 = nothing ;
231
+ ddvs= map (diff2term ∘ Differential (independent_variable (sys)), dvs),
232
+ version = nothing ,
233
+ #=
234
+ tgrad=false,
235
+ jac = false,
236
+ sparse = false,
237
+ =#
238
+ simplify= false ,
239
+ eval_expression = true ,
240
+ eval_module = @__MODULE__ ,
241
+ kwargs... ) where {iip}
242
+
243
+ f_gen = generate_function (sys, dvs, ps; implicit_dae = true , expression= Val{eval_expression}, expression_module= eval_module, kwargs... )
244
+ f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction (eval_module, ex) for ex in f_gen) : f_gen
245
+ f (du,u,p,t) = f_oop (du,u,p,t)
246
+ f (out,du,u,p,t) = f_iip (out,du,u,p,t)
247
+
248
+ # TODO : Jacobian sparsity / sparse Jacobian / dense Jacobian
249
+
250
+ #=
251
+ observedfun = let sys = sys, dict = Dict()
252
+ # TODO : We don't have enought information to reconstruct arbitrary state
253
+ # in general from `(u, p, t)`, e.g. `a ~ D(x)`.
254
+ function generated_observed(obsvar, u, p, t)
255
+ obs = get!(dict, value(obsvar)) do
256
+ build_explicit_observed_function(sys, obsvar)
257
+ end
258
+ obs(u, p, t)
259
+ end
260
+ end
261
+ =#
262
+
263
+ DAEFunction {iip} (
264
+ f,
265
+ syms = Symbol .(dvs),
266
+ # missing fields in `DAEFunction`
267
+ # indepsym = Symbol(independent_variable(sys)),
268
+ # observed = observedfun,
269
+ )
270
+ end
271
+
272
+ """
273
+ ```julia
274
+ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
205
275
ps = parameters(sys);
206
276
version = nothing, tgrad=false,
207
277
jac = false,
@@ -277,6 +347,7 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
277
347
end
278
348
279
349
function process_DEProblem (constructor, sys:: AbstractODESystem ,u0map,parammap;
350
+ implicit_dae = false , du0map = nothing ,
280
351
version = nothing , tgrad= false ,
281
352
jac = false ,
282
353
checkbounds = false , sparse = false ,
@@ -287,27 +358,80 @@ function process_DEProblem(constructor, sys::AbstractODESystem,u0map,parammap;
287
358
dvs = states (sys)
288
359
ps = parameters (sys)
289
360
defs = defaults (sys)
361
+ iv = independent_variable (sys)
290
362
291
363
u0 = varmap_to_vars (u0map,dvs; defaults= defs)
364
+ if implicit_dae && du0map != = nothing
365
+ ddvs = map (Differential (iv), dvs)
366
+ du0 = varmap_to_vars (du0map, ddvs; defaults= defaults, toterm= identity)
367
+ else
368
+ du0 = nothing
369
+ ddvs = nothing
370
+ end
292
371
p = varmap_to_vars (parammap,ps; defaults= defs)
293
372
294
373
if u0 != = nothing
295
374
length (dvs) == length (u0) || throw (ArgumentError (" States ($(length (dvs)) ) and initial conditions ($(length (u0)) ) are of different lengths." ))
296
375
end
297
376
298
- f = constructor (sys,dvs,ps,u0;tgrad= tgrad,jac= jac,checkbounds= checkbounds,
377
+ f = constructor (sys,dvs,ps,u0;ddvs = ddvs, tgrad= tgrad,jac= jac,checkbounds= checkbounds,
299
378
linenumbers= linenumbers,parallel= parallel,simplify= simplify,
300
379
sparse= sparse,eval_expression= eval_expression,kwargs... )
301
- return f, u0, p
380
+ implicit_dae ? ( f, du0, u0, p) : (f, u0, p)
302
381
end
303
382
304
383
function ODEFunctionExpr (sys:: AbstractODESystem , args... ; kwargs... )
305
384
ODEFunctionExpr {true} (sys, args... ; kwargs... )
306
385
end
307
386
387
+ """
388
+ ```julia
389
+ function DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
390
+ ps = parameters(sys);
391
+ version = nothing, tgrad=false,
392
+ jac = false,
393
+ sparse = false,
394
+ kwargs...) where {iip}
395
+ ```
396
+
397
+ Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
398
+ The arguments `dvs` and `ps` are used to set the order of the dependent
399
+ variable and parameter vectors, respectively.
400
+ """
401
+ struct DAEFunctionExpr{iip} end
308
402
309
- function DiffEqBase. ODEProblem (sys:: AbstractODESystem , args... ; kwargs... )
310
- ODEProblem {true} (sys, args... ; kwargs... )
403
+ struct DAEFunctionClosure{O, I} <: Function
404
+ f_oop:: O
405
+ f_iip:: I
406
+ end
407
+ (f:: DAEFunctionClosure )(du, u, p, t) = f. f_oop (du, u, p, t)
408
+ (f:: DAEFunctionClosure )(out, du, u, p, t) = f. f_iip (out, du, u, p, t)
409
+
410
+ function DAEFunctionExpr {iip} (sys:: AbstractODESystem , dvs = states (sys),
411
+ ps = parameters (sys), u0 = nothing ;
412
+ version = nothing , tgrad= false ,
413
+ jac = false ,
414
+ linenumbers = false ,
415
+ sparse = false , simplify= false ,
416
+ kwargs... ) where {iip}
417
+ f_oop, f_iip = generate_function (sys, dvs, ps; expression= Val{true }, implicit_dae = true , kwargs... )
418
+ fsym = gensym (:f )
419
+ _f = :($ fsym = $ DAEFunctionClosure ($ f_oop, $ f_iip))
420
+ ex = quote
421
+ $ _f
422
+ ODEFunction {$iip} ($ fsym,)
423
+ end
424
+ ! linenumbers ? striplines (ex) : ex
425
+ end
426
+
427
+ function DAEFunctionExpr (sys:: AbstractODESystem , args... ; kwargs... )
428
+ DAEFunctionExpr {true} (sys, args... ; kwargs... )
429
+ end
430
+
431
+ for P in [:ODEProblem , :DAEProblem ]
432
+ @eval function DiffEqBase. $P (sys:: AbstractODESystem , args... ; kwargs... )
433
+ $ P {true} (sys, args... ; kwargs... )
434
+ end
311
435
end
312
436
313
437
"""
333
457
334
458
"""
335
459
```julia
336
- function DiffEqBase.ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
460
+ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
461
+ parammap=DiffEqBase.NullParameters();
462
+ version = nothing, tgrad=false,
463
+ jac = false,
464
+ checkbounds = false, sparse = false,
465
+ simplify=false,
466
+ linenumbers = true, parallel=SerialForm(),
467
+ kwargs...) where iip
468
+ ```
469
+
470
+ Generates an DAEProblem from an ODESystem and allows for automatically
471
+ symbolically calculating numerical enhancements.
472
+ """
473
+ function DiffEqBase. DAEProblem {iip} (sys:: AbstractODESystem ,du0map,u0map,tspan,
474
+ parammap= DiffEqBase. NullParameters ();kwargs... ) where iip
475
+ f, du0, u0, p = process_DEProblem (
476
+ DAEFunction{iip}, sys, u0map, parammap;
477
+ implicit_dae= true , du0map= du0map, kwargs...
478
+ )
479
+ diffvars = collect_differential_variables (sys)
480
+ sts = states (sys)
481
+ differential_vars = map (Base. Fix2 (in, diffvars), sts)
482
+ DAEProblem {iip} (f,du0,u0,tspan,p;differential_vars= differential_vars,kwargs... )
483
+ end
484
+
485
+ """
486
+ ```julia
487
+ function ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
337
488
parammap=DiffEqBase.NullParameters();
338
489
version = nothing, tgrad=false,
339
490
jac = false,
@@ -371,6 +522,53 @@ function ODEProblemExpr(sys::AbstractODESystem, args...; kwargs...)
371
522
ODEProblemExpr {true} (sys, args... ; kwargs... )
372
523
end
373
524
525
+ """
526
+ ```julia
527
+ function DAEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
528
+ parammap=DiffEqBase.NullParameters();
529
+ version = nothing, tgrad=false,
530
+ jac = false,
531
+ checkbounds = false, sparse = false,
532
+ linenumbers = true, parallel=SerialForm(),
533
+ skipzeros=true, fillzeros=true,
534
+ simplify=false,
535
+ kwargs...) where iip
536
+ ```
537
+
538
+ Generates a Julia expression for constructing an ODEProblem from an
539
+ ODESystem and allows for automatically symbolically calculating
540
+ numerical enhancements.
541
+ """
542
+ struct DAEProblemExpr{iip} end
543
+
544
+ function DAEProblemExpr {iip} (sys:: AbstractODESystem ,du0map,u0map,tspan,
545
+ parammap= DiffEqBase. NullParameters ();
546
+ kwargs... ) where iip
547
+ f, du0, u0, p = process_DEProblem (
548
+ DAEFunctionExpr{iip}, sys, u0map, parammap;
549
+ implicit_dae= true , du0map= du0map, kwargs...
550
+ )
551
+ linenumbers = get (kwargs, :linenumbers , true )
552
+ diffvars = collect_differential_variables (sys)
553
+ sts = states (sys)
554
+ differential_vars = map (Base. Fix2 (in, diffvars), sts)
555
+
556
+ ex = quote
557
+ f = $ f
558
+ u0 = $ u0
559
+ du0 = $ du0
560
+ tspan = $ tspan
561
+ p = $ p
562
+ differential_vars = $ differential_vars
563
+ DAEProblem {$iip} (f,du0,u0,tspan,p;differential_vars= differential_vars,$ (kwargs... ))
564
+ end
565
+ ! linenumbers ? striplines (ex) : ex
566
+ end
567
+
568
+ function DAEProblemExpr (sys:: AbstractODESystem , args... ; kwargs... )
569
+ DAEProblemExpr {true} (sys, args... ; kwargs... )
570
+ end
571
+
374
572
375
573
# ## Enables Steady State Problems ###
376
574
function DiffEqBase. SteadyStateProblem (sys:: AbstractODESystem , args... ; kwargs... )
0 commit comments