Skip to content

Commit 4e542f2

Browse files
author
oscarddssmith
committed
AutoSpecialize for NonlinearProblem
1 parent ca77954 commit 4e542f2

File tree

2 files changed

+39
-32
lines changed

2 files changed

+39
-32
lines changed

src/problems/nonlinear_problems.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ function IntervalNonlinearProblem(f, tspan, p = NullParameters(); kwargs...)
104104
IntervalNonlinearProblem(IntervalNonlinearFunction(f), tspan, p; kwargs...)
105105
end
106106

107+
108+
_default_nl_specialize(p) = sizeof(p)==0 || ismutable(p) ? AutoSpecialize : FullSpecialize
109+
107110
@doc doc"""
108111
109112
Defines a nonlinear system problem.
@@ -183,7 +186,7 @@ mutable struct NonlinearProblem{uType, isinplace, P, F, K, PT} <:
183186
This is determined automatically, but not inferred.
184187
"""
185188
function NonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip}
186-
NonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
189+
NonlinearProblem{iip}(NonlinearFunction{iip, _default_nl_specialize(p)}(f), u0, p; kwargs...)
187190
end
188191
end
189192

@@ -198,7 +201,9 @@ function NonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters()
198201
end
199202

200203
function NonlinearProblem(f, u0, p = NullParameters(); kwargs...)
201-
NonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
204+
iip = isinplace(f, 3)
205+
206+
NonlinearProblem(NonlinearFunction{iip, _default_nl_specialize(p)}(f), u0, p; kwargs...)
202207
end
203208

204209
"""

src/scimlfunctions.jl

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1724,19 +1724,16 @@ For more details on this argument, see the ODEFunction documentation.
17241724
17251725
The fields of the NonlinearFunction type directly match the names of the inputs.
17261726
"""
1727-
struct NonlinearFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
1727+
struct NonlinearFunction{iip, specialize, F, TMM, Ta, TJ, JVP, VJP, JP, SP,
17281728
TPJ, O, TCV, SYS, RP, ID} <: AbstractNonlinearFunction{iip}
17291729
f::F
17301730
mass_matrix::TMM
17311731
analytic::Ta
1732-
tgrad::Tt
17331732
jac::TJ
17341733
jvp::JVP
17351734
vjp::VJP
17361735
jac_prototype::JP
17371736
sparsity::SP
1738-
Wfact::TW
1739-
Wfact_t::TWt
17401737
paramjac::TPJ
17411738
observed::O
17421739
colorvec::TCV
@@ -3801,23 +3798,17 @@ end
38013798
SDDEFunction(f::SDDEFunction; kwargs...) = f
38023799

38033800
function NonlinearFunction{iip, specialize}(f;
3804-
mass_matrix = __has_mass_matrix(f) ?
3805-
f.mass_matrix :
3806-
I,
3807-
analytic = __has_analytic(f) ? f.analytic :
3801+
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
3802+
analytic = __has_analytic(f) ? Void(f.analytic) :
38083803
nothing,
3809-
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
3810-
jac = __has_jac(f) ? f.jac : nothing,
3811-
jvp = __has_jvp(f) ? f.jvp : nothing,
3812-
vjp = __has_vjp(f) ? f.vjp : nothing,
3804+
jac = __has_jac(f) ? Void(f.jac) : nothing,
3805+
jvp = __has_jvp(f) ? Void(f.jvp) : nothing,
3806+
vjp = __has_vjp(f) ? Void(f.vjp) : nothing,
38133807
jac_prototype = __has_jac_prototype(f) ?
38143808
f.jac_prototype : nothing,
38153809
sparsity = __has_sparsity(f) ? f.sparsity :
38163810
jac_prototype,
3817-
Wfact = __has_Wfact(f) ? f.Wfact : nothing,
3818-
Wfact_t = __has_Wfact_t(f) ? f.Wfact_t :
3819-
nothing,
3820-
paramjac = __has_paramjac(f) ? f.paramjac :
3811+
paramjac = __has_paramjac(f) ? Void(f.paramjac) :
38213812
nothing,
38223813
syms = nothing,
38233814
paramsyms = nothing,
@@ -3864,40 +3855,51 @@ function NonlinearFunction{iip, specialize}(f;
38643855
sys = sys_or_symbolcache(sys, syms, paramsyms)
38653856
if specialize === NoSpecialize
38663857
NonlinearFunction{iip, specialize,
3858+
Any, Any, Any, Any,
38673859
Any, Any, Any, Any, Any,
3868-
Any, Any, Any, Any, Any,
3869-
Any, Any, Any,
3860+
Any,
38703861
typeof(_colorvec), Any, Any, Any}(_f, mass_matrix,
3871-
analytic, tgrad, jac,
3862+
analytic, jac,
38723863
jvp, vjp,
38733864
jac_prototype,
3874-
sparsity, Wfact,
3875-
Wfact_t, paramjac,
3865+
sparsity, paramjac,
38763866
observed,
38773867
_colorvec, sys, resid_prototype, initialization_data)
3868+
elseif specialize === AutoSpecialize && iip
3869+
NonlinearFunction{iip, specialize,
3870+
Void, typeof(mass_matrix),
3871+
analytic isa Void ? Void : typeof(analytic),
3872+
jac isa Void ? Void : typeof(jac),
3873+
jvp isa Void ? Void : typeof(jvp),
3874+
vjp isa Void ? Void : typeof(vjp),
3875+
typeof(jac_prototype),
3876+
typeof(sparsity), typeof(paramjac),
3877+
observed isa Void ? Void : typeof(observed),
3878+
typeof(_colorvec), typeof(sys), typeof(resid_prototype),
3879+
typeof(initialization_data)}(Void(_f), mass_matrix,
3880+
analytic, jac,
3881+
jvp, vjp, jac_prototype, sparsity, paramjac,
3882+
observed, _colorvec, sys, resid_prototype, initialization_data)
38783883
else
38793884
NonlinearFunction{iip, specialize,
3880-
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
3885+
typeof(_f), typeof(mass_matrix), typeof(analytic),
38813886
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
3882-
typeof(sparsity), typeof(Wfact),
3883-
typeof(Wfact_t), typeof(paramjac),
3887+
typeof(sparsity), typeof(paramjac),
38843888
typeof(observed),
38853889
typeof(_colorvec), typeof(sys), typeof(resid_prototype),
38863890
typeof(initialization_data)}(_f, mass_matrix,
3887-
analytic, tgrad, jac,
3888-
jvp, vjp, jac_prototype, sparsity,
3889-
Wfact,
3890-
Wfact_t, paramjac,
3891+
analytic, jac,
3892+
jvp, vjp, jac_prototype, sparsity, paramjac,
38913893
observed, _colorvec, sys, resid_prototype, initialization_data)
38923894
end
38933895
end
38943896

38953897
function NonlinearFunction{iip}(f; kwargs...) where {iip}
3896-
NonlinearFunction{iip, FullSpecialize}(f; kwargs...)
3898+
NonlinearFunction{iip, AutoSpecialize}(f; kwargs...)
38973899
end
38983900
NonlinearFunction{iip}(f::NonlinearFunction; kwargs...) where {iip} = f
38993901
function NonlinearFunction(f; kwargs...)
3900-
NonlinearFunction{isinplace(f, 3), FullSpecialize}(f; kwargs...)
3902+
NonlinearFunction{isinplace(f, 3), AutoSpecialize}(f; kwargs...)
39013903
end
39023904
NonlinearFunction(f::NonlinearFunction; kwargs...) = f
39033905

0 commit comments

Comments
 (0)