@@ -248,17 +248,28 @@ function DiffEqBase.ODEFunction(sys::AbstractODESystem, args...; kwargs...)
248
248
ODEFunction {true} (sys, args... ; kwargs... )
249
249
end
250
250
251
- function DiffEqBase. ODEFunction {iip} (sys:: AbstractODESystem , dvs = states (sys),
252
- ps = parameters (sys), u0 = nothing ;
253
- version = nothing , tgrad = false ,
254
- jac = false ,
255
- eval_expression = true ,
256
- sparse = false , simplify = false ,
257
- eval_module = @__MODULE__ ,
258
- steady_state = false ,
259
- checkbounds = false ,
260
- sparsity = false ,
261
- kwargs... ) where {iip}
251
+ function DiffEqBase. ODEFunction {true} (sys:: AbstractODESystem , args... ;
252
+ kwargs... )
253
+ ODEFunction {true, SciMLBase.AutoSpecialize} (sys, args... ; kwargs... )
254
+ end
255
+
256
+ function DiffEqBase. ODEFunction {false} (sys:: AbstractODESystem , args... ;
257
+ kwargs... )
258
+ ODEFunction {false, SciMLBase.FullSpecialize} (sys, args... ; kwargs... )
259
+ end
260
+
261
+ function DiffEqBase. ODEFunction {iip, specialize} (sys:: AbstractODESystem , dvs = states (sys),
262
+ ps = parameters (sys), u0 = nothing ;
263
+ version = nothing , tgrad = false ,
264
+ jac = false , p = nothing ,
265
+ t = nothing ,
266
+ eval_expression = true ,
267
+ sparse = false , simplify = false ,
268
+ eval_module = @__MODULE__ ,
269
+ steady_state = false ,
270
+ checkbounds = false ,
271
+ sparsity = false ,
272
+ kwargs... ) where {iip, specialize}
262
273
f_gen = generate_function (sys, dvs, ps; expression = Val{eval_expression},
263
274
expression_module = eval_module, checkbounds = checkbounds,
264
275
kwargs... )
@@ -267,6 +278,13 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
267
278
f (u, p, t) = f_oop (u, p, t)
268
279
f (du, u, p, t) = f_iip (du, u, p, t)
269
280
281
+ if specialize === SciMLBase. FunctionWrapperSpecialize && iip
282
+ if u0 === nothing || p === nothing || t === nothing
283
+ error (" u0, p, and t must be specified for FunctionWrapperSpecialize on ODEFunction." )
284
+ end
285
+ f = SciMLBase. wrapfun_iip (f, (u0, u0, p, t))
286
+ end
287
+
270
288
if tgrad
271
289
tgrad_gen = generate_tgrad (sys, dvs, ps;
272
290
simplify = simplify,
@@ -338,16 +356,16 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
338
356
else
339
357
nothing
340
358
end
341
- ODEFunction {iip} (f,
342
- sys = sys,
343
- jac = _jac === nothing ? nothing : _jac,
344
- tgrad = _tgrad === nothing ? nothing : _tgrad,
345
- mass_matrix = _M,
346
- jac_prototype = jac_prototype,
347
- syms = Symbol .(states (sys)),
348
- indepsym = Symbol (get_iv (sys)),
349
- observed = observedfun,
350
- sparsity = sparsity ? jacobian_sparsity (sys) : nothing )
359
+ ODEFunction {iip, specialize } (f,
360
+ sys = sys,
361
+ jac = _jac === nothing ? nothing : _jac,
362
+ tgrad = _tgrad === nothing ? nothing : _tgrad,
363
+ mass_matrix = _M,
364
+ jac_prototype = jac_prototype,
365
+ syms = Symbol .(states (sys)),
366
+ indepsym = Symbol (get_iv (sys)),
367
+ observed = observedfun,
368
+ sparsity = sparsity ? jacobian_sparsity (sys) : nothing )
351
369
end
352
370
353
371
"""
371
389
function DiffEqBase. DAEFunction {iip} (sys:: AbstractODESystem , dvs = states (sys),
372
390
ps = parameters (sys), u0 = nothing ;
373
391
ddvs = map (diff2term ∘ Differential (get_iv (sys)), dvs),
374
- version = nothing ,
392
+ version = nothing , p = nothing ,
375
393
jac = false ,
376
394
eval_expression = true ,
377
395
sparse = false , simplify = false ,
463
481
function ODEFunctionExpr {iip} (sys:: AbstractODESystem , dvs = states (sys),
464
482
ps = parameters (sys), u0 = nothing ;
465
483
version = nothing , tgrad = false ,
466
- jac = false ,
484
+ jac = false , p = nothing ,
467
485
linenumbers = false ,
468
486
sparse = false , simplify = false ,
469
487
steady_state = false ,
@@ -542,6 +560,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
542
560
543
561
u0 = varmap_to_vars (u0map, dvs; defaults = defs, tofloat = true )
544
562
p = varmap_to_vars (parammap, ps; defaults = defs, tofloat = ! use_union, use_union)
563
+ p = p === nothing ? SciMLBase. NullParameters () : p
564
+
545
565
if implicit_dae && du0map != = nothing
546
566
ddvs = map (Differential (iv), dvs)
547
567
defs = mergedefaults (defs, du0map, ddvs)
@@ -555,7 +575,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
555
575
check_eqs_u0 (eqs, dvs, u0; kwargs... )
556
576
557
577
f = constructor (sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
558
- checkbounds = checkbounds,
578
+ checkbounds = checkbounds, p = p,
559
579
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
560
580
sparse = sparse, eval_expression = eval_expression, kwargs... )
561
581
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
591
611
function DAEFunctionExpr {iip} (sys:: AbstractODESystem , dvs = states (sys),
592
612
ps = parameters (sys), u0 = nothing ;
593
613
version = nothing , tgrad = false ,
594
- jac = false ,
614
+ jac = false , p = nothing ,
595
615
linenumbers = false ,
596
616
sparse = false , simplify = false ,
597
617
kwargs... ) where {iip}
@@ -629,12 +649,22 @@ function DiffEqBase.ODEProblem(sys::AbstractODESystem, args...; kwargs...)
629
649
ODEProblem {true} (sys, args... ; kwargs... )
630
650
end
631
651
632
- function DiffEqBase. ODEProblem {iip} (sys:: AbstractODESystem , u0map, tspan,
633
- parammap = DiffEqBase. NullParameters ();
634
- callback = nothing ,
635
- check_length = true , kwargs... ) where {iip}
652
+ function DiffEqBase. ODEProblem {true} (sys:: AbstractODESystem , args... ; kwargs... )
653
+ ODEProblem {true, SciMLBase.AutoSpecialize} (sys, args... ; kwargs... )
654
+ end
655
+
656
+ function DiffEqBase. ODEProblem {false} (sys:: AbstractODESystem , args... ; kwargs... )
657
+ ODEProblem {false, SciMLBase.FullSpecialize} (sys, args... ; kwargs... )
658
+ end
659
+
660
+ function DiffEqBase. ODEProblem {iip, specialize} (sys:: AbstractODESystem , u0map, tspan,
661
+ parammap = DiffEqBase. NullParameters ();
662
+ callback = nothing ,
663
+ check_length = true ,
664
+ kwargs... ) where {iip, specialize}
636
665
has_difference = any (isdifferenceeq, equations (sys))
637
- f, u0, p = process_DEProblem (ODEFunction{iip}, sys, u0map, parammap;
666
+ f, u0, p = process_DEProblem (ODEFunction{iip, specialize}, sys, u0map, parammap;
667
+ t = tspan != = nothing ? tspan[1 ] : tspan,
638
668
has_difference = has_difference,
639
669
check_length, kwargs... )
640
670
cbs = process_events (sys; callback, has_difference, kwargs... )
0 commit comments