Skip to content

Commit 9f20b03

Browse files
Merge pull request #1770 from SciML/myb/iip_nospecialize
Expose specialization options for ODE{Function, Problem}
2 parents 64703fb + 834b8a0 commit 9f20b03

File tree

5 files changed

+76
-46
lines changed

5 files changed

+76
-46
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ ArrayInterfaceCore = "0.1.1"
5050
Combinatorics = "1"
5151
ConstructionBase = "1"
5252
DataStructures = "0.17, 0.18"
53-
DiffEqBase = "6.83.0"
53+
DiffEqBase = "6.103.0"
5454
DiffEqCallbacks = "2.16"
5555
DiffRules = "0.1, 1.0"
5656
Distributions = "0.23, 0.24, 0.25"
@@ -70,7 +70,7 @@ NonlinearSolve = "0.3.8"
7070
RecursiveArrayTools = "2.3"
7171
Reexport = "0.2, 1"
7272
RuntimeGeneratedFunctions = "0.4.3, 0.5"
73-
SciMLBase = "1.49"
73+
SciMLBase = "1.54"
7474
Setfield = "0.7, 0.8, 1"
7575
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
7676
StaticArrays = "0.10, 0.11, 0.12, 1.0"

src/structural_transformation/codegen.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -329,18 +329,18 @@ function build_torn_function(sys;
329329
end
330330
end
331331

332-
ODEFunction{true}(@RuntimeGeneratedFunction(expr),
333-
sparsity = jacobian_sparsity ?
334-
torn_system_with_nlsolve_jacobian_sparsity(state,
335-
var_eq_matching,
336-
var_sccs,
337-
nlsolve_scc_idxs,
338-
eqs_idxs,
339-
states_idxs) :
340-
nothing,
341-
syms = syms,
342-
observed = observedfun,
343-
mass_matrix = mass_matrix), states
332+
ODEFunction{true, SciMLBase.AutoSpecialize}(@RuntimeGeneratedFunction(expr),
333+
sparsity = jacobian_sparsity ?
334+
torn_system_with_nlsolve_jacobian_sparsity(state,
335+
var_eq_matching,
336+
var_sccs,
337+
nlsolve_scc_idxs,
338+
eqs_idxs,
339+
states_idxs) :
340+
nothing,
341+
syms = syms,
342+
observed = observedfun,
343+
mass_matrix = mass_matrix), states
344344
end
345345
end
346346

src/systems/abstractsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,7 @@ function linearization_function(sys::AbstractSystem, inputs,
10271027
alge_idxs = alge_idxs,
10281028
input_idxs = input_idxs,
10291029
sts = states(sys),
1030-
fun = ODEFunction(sys),
1030+
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys),
10311031
h = ModelingToolkit.build_explicit_observed_function(sys, outputs),
10321032
chunk = ForwardDiff.Chunk(input_idxs)
10331033

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -248,17 +248,28 @@ function DiffEqBase.ODEFunction(sys::AbstractODESystem, args...; kwargs...)
248248
ODEFunction{true}(sys, args...; kwargs...)
249249
end
250250

251-
function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
252-
ps = parameters(sys), u0 = nothing;
253-
version = nothing, tgrad = false,
254-
jac = false,
255-
eval_expression = true,
256-
sparse = false, simplify = false,
257-
eval_module = @__MODULE__,
258-
steady_state = false,
259-
checkbounds = false,
260-
sparsity = false,
261-
kwargs...) where {iip}
251+
function DiffEqBase.ODEFunction{true}(sys::AbstractODESystem, args...;
252+
kwargs...)
253+
ODEFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
254+
end
255+
256+
function DiffEqBase.ODEFunction{false}(sys::AbstractODESystem, args...;
257+
kwargs...)
258+
ODEFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
259+
end
260+
261+
function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = states(sys),
262+
ps = parameters(sys), u0 = nothing;
263+
version = nothing, tgrad = false,
264+
jac = false, p = nothing,
265+
t = nothing,
266+
eval_expression = true,
267+
sparse = false, simplify = false,
268+
eval_module = @__MODULE__,
269+
steady_state = false,
270+
checkbounds = false,
271+
sparsity = false,
272+
kwargs...) where {iip, specialize}
262273
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
263274
expression_module = eval_module, checkbounds = checkbounds,
264275
kwargs...)
@@ -267,6 +278,13 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
267278
f(u, p, t) = f_oop(u, p, t)
268279
f(du, u, p, t) = f_iip(du, u, p, t)
269280

281+
if specialize === SciMLBase.FunctionWrapperSpecialize && iip
282+
if u0 === nothing || p === nothing || t === nothing
283+
error("u0, p, and t must be specified for FunctionWrapperSpecialize on ODEFunction.")
284+
end
285+
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
286+
end
287+
270288
if tgrad
271289
tgrad_gen = generate_tgrad(sys, dvs, ps;
272290
simplify = simplify,
@@ -338,16 +356,16 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
338356
else
339357
nothing
340358
end
341-
ODEFunction{iip}(f,
342-
sys = sys,
343-
jac = _jac === nothing ? nothing : _jac,
344-
tgrad = _tgrad === nothing ? nothing : _tgrad,
345-
mass_matrix = _M,
346-
jac_prototype = jac_prototype,
347-
syms = Symbol.(states(sys)),
348-
indepsym = Symbol(get_iv(sys)),
349-
observed = observedfun,
350-
sparsity = sparsity ? jacobian_sparsity(sys) : nothing)
359+
ODEFunction{iip, specialize}(f,
360+
sys = sys,
361+
jac = _jac === nothing ? nothing : _jac,
362+
tgrad = _tgrad === nothing ? nothing : _tgrad,
363+
mass_matrix = _M,
364+
jac_prototype = jac_prototype,
365+
syms = Symbol.(states(sys)),
366+
indepsym = Symbol(get_iv(sys)),
367+
observed = observedfun,
368+
sparsity = sparsity ? jacobian_sparsity(sys) : nothing)
351369
end
352370

353371
"""
@@ -371,7 +389,7 @@ end
371389
function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
372390
ps = parameters(sys), u0 = nothing;
373391
ddvs = map(diff2term Differential(get_iv(sys)), dvs),
374-
version = nothing,
392+
version = nothing, p = nothing,
375393
jac = false,
376394
eval_expression = true,
377395
sparse = false, simplify = false,
@@ -463,7 +481,7 @@ end
463481
function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
464482
ps = parameters(sys), u0 = nothing;
465483
version = nothing, tgrad = false,
466-
jac = false,
484+
jac = false, p = nothing,
467485
linenumbers = false,
468486
sparse = false, simplify = false,
469487
steady_state = false,
@@ -542,6 +560,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
542560

543561
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
544562
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = !use_union, use_union)
563+
p = p === nothing ? SciMLBase.NullParameters() : p
564+
545565
if implicit_dae && du0map !== nothing
546566
ddvs = map(Differential(iv), dvs)
547567
defs = mergedefaults(defs, du0map, ddvs)
@@ -555,7 +575,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
555575
check_eqs_u0(eqs, dvs, u0; kwargs...)
556576

557577
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
558-
checkbounds = checkbounds,
578+
checkbounds = checkbounds, p = p,
559579
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
560580
sparse = sparse, eval_expression = eval_expression, kwargs...)
561581
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
@@ -591,7 +611,7 @@ end
591611
function DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
592612
ps = parameters(sys), u0 = nothing;
593613
version = nothing, tgrad = false,
594-
jac = false,
614+
jac = false, p = nothing,
595615
linenumbers = false,
596616
sparse = false, simplify = false,
597617
kwargs...) where {iip}
@@ -629,12 +649,22 @@ function DiffEqBase.ODEProblem(sys::AbstractODESystem, args...; kwargs...)
629649
ODEProblem{true}(sys, args...; kwargs...)
630650
end
631651

632-
function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem, u0map, tspan,
633-
parammap = DiffEqBase.NullParameters();
634-
callback = nothing,
635-
check_length = true, kwargs...) where {iip}
652+
function DiffEqBase.ODEProblem{true}(sys::AbstractODESystem, args...; kwargs...)
653+
ODEProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
654+
end
655+
656+
function DiffEqBase.ODEProblem{false}(sys::AbstractODESystem, args...; kwargs...)
657+
ODEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
658+
end
659+
660+
function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map, tspan,
661+
parammap = DiffEqBase.NullParameters();
662+
callback = nothing,
663+
check_length = true,
664+
kwargs...) where {iip, specialize}
636665
has_difference = any(isdifferenceeq, equations(sys))
637-
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap;
666+
f, u0, p = process_DEProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
667+
t = tspan !== nothing ? tspan[1] : tspan,
638668
has_difference = has_difference,
639669
check_length, kwargs...)
640670
cbs = process_events(sys; callback, has_difference, kwargs...)

test/linearize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit
1+
using ModelingToolkit, Test
22

33
# r is an input, and y is an output.
44
@variables t x(t)=0 y(t)=0 u(t)=0 r(t)=0

0 commit comments

Comments
 (0)