diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 780e68e80..aef3272db 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -587,7 +587,7 @@ $(TYPEDEF) Base for types defining SciML functions. """ -abstract type AbstractSciMLFunction{iip} <: Function end +abstract type AbstractSciMLFunction{iip} end """ $(TYPEDEF) @@ -622,7 +622,7 @@ abstract type AbstractHistoryFunction end """ $(TYPEDEF) """ -abstract type AbstractReactionNetwork <: Function end +abstract type AbstractReactionNetwork end """ $(TYPEDEF) diff --git a/src/function_wrappers.jl b/src/function_wrappers.jl index 25ff6e4cc..fc5a390f7 100644 --- a/src/function_wrappers.jl +++ b/src/function_wrappers.jl @@ -1,4 +1,6 @@ -mutable struct TimeGradientWrapper{iip, fType, uType, P} <: AbstractSciMLFunction{iip} +abstract type AbstractWrappedFunction{iip} end +isinplace(f::AbstractWrappedFunction{iip}) where {iip} = iip +mutable struct TimeGradientWrapper{iip, fType, uType, P} <: AbstractWrappedFunction{iip} f::fType uprev::uType p::P @@ -18,7 +20,7 @@ end (ff::TimeGradientWrapper{false})(t) = ff.f(ff.uprev, ff.p, t) -mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractSciMLFunction{iip} +mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractWrappedFunction{iip} f::fType t::tType p::P @@ -41,7 +43,7 @@ end (ff::UJacobianWrapper{false})(uprev) = ff.f(uprev, ff.p, ff.t) (ff::UJacobianWrapper{false})(uprev, p, t) = ff.f(uprev, p, t) -mutable struct TimeDerivativeWrapper{iip, F, uType, P} <: AbstractSciMLFunction{iip} +mutable struct TimeDerivativeWrapper{iip, F, uType, P} <: AbstractWrappedFunction{iip} f::F u::uType p::P @@ -58,7 +60,7 @@ end (ff::TimeDerivativeWrapper{true})(du1, t) = ff.f(du1, ff.u, ff.p, t) (ff::TimeDerivativeWrapper{true})(t) = (du1 = similar(ff.u); ff.f(du1, ff.u, ff.p, t); du1) -mutable struct UDerivativeWrapper{iip, F, tType, P} <: AbstractSciMLFunction{iip} +mutable struct UDerivativeWrapper{iip, F, tType, P} <: AbstractWrappedFunction{iip} f::F t::tType p::P @@ -73,7 +75,7 @@ UDerivativeWrapper(f::F, t, p) where {F} = UDerivativeWrapper{isinplace(f, 4)}(f (ff::UDerivativeWrapper{true})(du1, u) = ff.f(du1, u, ff.p, ff.t) (ff::UDerivativeWrapper{true})(u) = (du1 = similar(u); ff.f(du1, u, ff.p, ff.t); du1) -mutable struct ParamJacobianWrapper{iip, fType, tType, uType} <: AbstractSciMLFunction{iip} +mutable struct ParamJacobianWrapper{iip, fType, tType, uType} <: AbstractWrappedFunction{iip} f::fType t::tType u::uType @@ -95,7 +97,7 @@ function (ff::ParamJacobianWrapper{false})(du1, p) du1 .= ff.f(ff.u, p, ff.t) end -mutable struct JacobianWrapper{iip, fType, pType} <: AbstractSciMLFunction{iip} +mutable struct JacobianWrapper{iip, fType, pType} <: AbstractWrappedFunction{iip} f::fType p::pType end