diff --git a/Project.toml b/Project.toml index 7fd1bf242..f28000271 100644 --- a/Project.toml +++ b/Project.toml @@ -51,7 +51,6 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [extensions] DiffEqBaseCUDAExt = "CUDA" DiffEqBaseChainRulesCoreExt = "ChainRulesCore" -DiffEqBaseDistributionsExt = "Distributions" DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"] DiffEqBaseForwardDiffExt = ["ForwardDiff"] DiffEqBaseGTPSAExt = "GTPSA" diff --git a/ext/DiffEqBaseDistributionsExt.jl b/ext/DiffEqBaseDistributionsExt.jl deleted file mode 100644 index e84a5509a..000000000 --- a/ext/DiffEqBaseDistributionsExt.jl +++ /dev/null @@ -1,8 +0,0 @@ -module DiffEqBaseDistributionsExt - -using Distributions, DiffEqBase - -DiffEqBase.handle_distribution_u0(_u0::Distributions.Sampleable) = rand(_u0) -DiffEqBase.isdistribution(_u0::Distributions.Sampleable) = true - -end diff --git a/ext/DiffEqBaseForwardDiffExt.jl b/ext/DiffEqBaseForwardDiffExt.jl index b328eedbc..0fd81e0dc 100644 --- a/ext/DiffEqBaseForwardDiffExt.jl +++ b/ext/DiffEqBaseForwardDiffExt.jl @@ -6,34 +6,13 @@ using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag, AbstractTimeseriesSolution, RecursiveArrayTools, reduce_tup, _promote_tspan, has_continuous_callback import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin, - promote_tspan, anyeltypedual, isdualtype, value, ODE_DEFAULT_NORM, - InternalITP, nextfloat_tdir, DualEltypeChecker, sse, unitfulvalue + promote_tspan, ODE_DEFAULT_NORM, + InternalITP, nextfloat_tdir +import SciMLBase: isdualtype, DualEltypeChecker, sse, __sum -eltypedual(x) = eltype(x) <: ForwardDiff.Dual -isdualtype(::Type{<:ForwardDiff.Dual}) = true const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1} dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1} -# Copy of the other prob2dtmin dispatch, just for optionality -function prob2dtmin(tspan, ::ForwardDiff.Dual, use_end_time) - t1, t2 = tspan - isfinite(t1) || throw(ArgumentError("t0 in the tspan `(t0, t1)` must be finite")) - if use_end_time && isfinite(t2 - t1) - return max(eps(t2), eps(t1)) - else - return max(eps(typeof(t1)), eps(t1)) - end -end - -function hasdualpromote(u0, t::Number) - hasmethod(ArrayInterface.promote_eltype, - Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) && - hasmethod(promote_rule, - Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) && - hasmethod(promote_rule, - Tuple{Type{eltype(u0)}, Type{typeof(t)}}) -end - const NORECOMPILE_IIP_SUPPORTED_ARGS = ( Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}, @@ -111,400 +90,24 @@ function wrapfun_iip(@nospecialize(ff)) FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt) end -promote_dual(::Type{T}, ::Type{T2}) where {T <: ForwardDiff.Dual, T2} = T -function promote_dual(::Type{T}, - ::Type{T2}) where {T <: ForwardDiff.Dual, T2 <: ForwardDiff.Dual} - T -end -promote_dual(::Type{T}, ::Type{T2}) where {T, T2 <: ForwardDiff.Dual} = T2 - -function promote_dual(::Type{T}, - ::Type{T2}) where {T3, T4, V, V2 <: ForwardDiff.Dual, N, N2, - T <: ForwardDiff.Dual{T3, V, N}, - T2 <: ForwardDiff.Dual{T4, V2, N2}} - T2 -end -function promote_dual(::Type{T}, - ::Type{T2}) where {T3, T4, V <: ForwardDiff.Dual, V2, N, N2, - T <: ForwardDiff.Dual{T3, V, N}, - T2 <: ForwardDiff.Dual{T4, V2, N2}} - T -end -function promote_dual(::Type{T}, - ::Type{T2}) where { - T3, V <: ForwardDiff.Dual, V2 <: ForwardDiff.Dual, - N, - T <: ForwardDiff.Dual{T3, V, N}, - T2 <: ForwardDiff.Dual{T3, V2, N}} - ForwardDiff.Dual{T3, promote_dual(V, V2), N} -end - -""" - promote_dual(::Type{T},::Type{T2}) - - -Is like the number promotion system, but always prefers a dual number type above -anything else. For higher order differentiation, it returns the most dualiest of -them all. This is then used to promote `u0` into the suspected highest differentiation -space for solving the equation. -""" -promote_dual(::Type{T}, ::Type{T2}) where {T, T2} = T - -# `reduce` and `map` are specialized on tuples to be unrolled (via recursion) -# Therefore, they can be type stable even with heterogeneous input types. -# We also don't care about allocating any temporaries with them, as it should -# all be unrolled and optimized away. -# Being unrolled also means const prop can work for things like -# `mapreduce(f, op, propertynames(x))` -# where `f` may call `getproperty` and thus have return type dependent -# on the particular symbol. -# `mapreduce` hasn't received any such specialization. -@inline diffeqmapreduce(f::F, op::OP, x::Tuple) where {F, OP} = reduce_tup(op, map(f, x)) -@inline function diffeqmapreduce(f::F, op::OP, x::NamedTuple) where {F, OP} - reduce_tup(op, map(f, x)) -end -# For other container types, we probably just want to call `mapreduce` -@inline diffeqmapreduce(f::F, op::OP, x) where {F, OP} = mapreduce(f, op, x, init = Any) - -getval(::Val{I}) where {I} = I -getval(::Type{Val{I}}) where {I} = I -getval(I::Int) = I - -const DUALCHECK_RECURSION_MAX = 10 - -function (dec::DualEltypeChecker)(::Val{Y}) where {Y} - isdefined(dec.x, Y) || return Any - getval(dec.counter) >= DUALCHECK_RECURSION_MAX && return Any - anyeltypedual(getfield(dec.x, Y), Val{getval(dec.counter)}) -end - -# Untyped dispatch: catch composite types, check all of their fields -""" - anyeltypedual(x) - - -Searches through a type to see if any of its values are parameters. This is used to -then promote other values to match the dual type. For example, if a user passes a parameter - -which is a `Dual` and a `u0` which is a `Float64`, after the first time step, `f(u,p,t) = p*u` -will change `u0` from `Float64` to `Dual`. Thus the state variable always needs to be converted -to a dual number before the solve. Worse still, this needs to be done in the case of -`f(du,u,p,t) = du[1] = p*u[1]`, and thus running `f` and taking the return value is not a valid -way to calculate the required state type. - -But given the properties of automatic differentiation requiring that differentiation of parameters -implies differentiation of state, we assume any dual parameters implies differentiation of state -and then attempt to upconvert `u0` to match that dual-ness. Because this changes types, this needs -to be specified at compiled time and thus cannot have a Bool-based opt out, so in the future this -may be extended to use a preference system to opt-out with a `UPCONVERT_DUALS`. In the case where -upconversion is not done automatically, the user is required to upconvert all initial conditions -themselves, for an example of how this can be confusing to a user see -https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-forced-differential-equation/82937 -""" -@generated function anyeltypedual(x, ::Type{Val{counter}}) where {counter} - x = x.name === Core.Compiler.typename(Type) ? x.parameters[1] : x - if isdualtype(x) - :($x) - elseif fieldnames(x) === () - :(Any) - elseif counter < DUALCHECK_RECURSION_MAX - T = diffeqmapreduce(x -> anyeltypedual(x, Val{counter + 1}), promote_dual, - x.parameters) - if T === Any || isconcretetype(T) - :($T) - else - :(diffeqmapreduce(DualEltypeChecker($x, $counter + 1), promote_dual, - map(Val, fieldnames($(typeof(x)))))) - end - else - :(Any) - end -end - -const FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE = """ - Failed to automatically detect ForwardDiff compatability of - the parameter object. In order for ForwardDiff.jl automatic - differentiation to work on a solution object, the state of - the differential equation or nonlinear solve (`u0`) needs to - be converted to a Dual type which matches the values being - differentiated. For example, for a loss function loss(p) - where `p`` is a `Vector{Float64}`, this conversion is - equivalent to: - - ```julia - # Convert u0 to match the new Dual element type of `p` - _prob = remake(prob, u0 = eltype(p).(prob.u0)) - ``` - - In most cases, SciML tools are able to do this conversion - automatically. However, it seems you have provided a - parameter type for which this automatic conversion has failed. - - To fix this, you can do the conversion yourself. For example, - if you have a parameter vector being optimized `p` which is - then put into an odd struct, you can manually convert `u0` - to match `p`: - - ```julia - function loss(p) - _prob = remake(prob, u0 = eltype(p).(prob.u0), p = MyStruct(p)) - sol = solve(_prob, ...) - # do stuff on sol - end - ``` - - Or you can define a dispatch on `DiffEqBase.anyeltypedual` - which tells the system what fields to interpret as the - differentiable parts. For example, to support ODESolutions - as parameters we tell it the data is `sol.u` and `sol.t` via: - - ```julia - function DiffEqBase.anyeltypedual(sol::ODESolution, counter = 0) - DiffEqBase.anyeltypedual((sol.u, sol.t)) - end - ``` - - To opt a type out of the dual checking, define an overload - that returns Any. For example: - - ```julia - function DiffEqBase.anyeltypedual(::YourType, ::Type{Val{counter}}) where {counter} - Any - end - ``` - - If you have defined this on a common type which should - be more generally supported, please open a pull request - adding this dispatch. If you need help defining this dispatch, - feel free to open an issue. - """ - -struct ForwardDiffAutomaticDetectionFailure <: Exception end - -function Base.showerror(io::IO, e::ForwardDiffAutomaticDetectionFailure) - print(io, FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE) -end - -function anyeltypedual(::Type{Union{}}) - throw(ForwardDiffAutomaticDetectionFailure()) -end - -function anyeltypedual(::Type{<:AbstractTimeseriesSolution{T, N}}, - ::Type{Val{counter}} = Val{0}) where {T, N, counter} - anyeltypedual(T) -end - -function anyeltypedual( - ::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where {T <: - NonlinearProblem{ - uType, iip, pType}} where {uType, iip, pType} - return anyeltypedual((uType, pType), Val{counter}) -end - -function anyeltypedual( - ::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where {T <: - NonlinearLeastSquaresProblem{ - uType, iip, pType}} where {uType, iip, pType} - return anyeltypedual((uType, pType), Val{counter}) -end - -function anyeltypedual(x::SciMLBase.RecipesBase.AbstractPlot, - ::Type{Val{counter}} = Val{0}) where {counter} - Any -end -function anyeltypedual(x::Returns, ::Type{Val{counter}} = Val{0}) where {counter} - anyeltypedual(x.value, Val{counter}) -end - -Base.@assume_effects :foldable function __anyeltypedual(::Type{T}) where {T} - if T isa Union - promote_dual(anyeltypedual(T.a), anyeltypedual(T.b)) - elseif hasproperty(T, :parameters) - mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) - else - T - end -end -function anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T} - __anyeltypedual(T) -end - -function anyeltypedual(::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where {T <: - Union{AbstractArray, Set}} - anyeltypedual(eltype(T)) -end -Base.@pure function __anyeltypedual_ntuple(::Type{T}) where {T <: NTuple} - if isconcretetype(eltype(T)) - return eltype(T) - end - if isempty(T.parameters) - Any - else - mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) - end -end -function anyeltypedual( - ::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: NTuple} - __anyeltypedual_ntuple(T) -end - -# Any in this context just means not Dual -function anyeltypedual( - x::SciMLBase.NullParameters, ::Type{Val{counter}} = Val{0}) where {counter} - Any -end - -function anyeltypedual(sol::RecursiveArrayTools.AbstractDiffEqArray, counter = 0) - diffeqmapreduce(anyeltypedual, promote_dual, (sol.u, sol.t)) -end - -function anyeltypedual(prob::Union{ODEProblem, SDEProblem, RODEProblem, DDEProblem}, - ::Type{Val{counter}} = Val{0}) where {counter} - anyeltypedual((prob.u0, prob.p, prob.tspan)) -end - -function anyeltypedual( - prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem, OptimizationProblem}, - ::Type{Val{counter}} = Val{0}) where {counter} - anyeltypedual((prob.u0, prob.p)) -end - -function anyeltypedual(x::Number, ::Type{Val{counter}} = Val{0}) where {counter} - anyeltypedual(typeof(x)) -end -function anyeltypedual( - x::Union{String, Symbol}, ::Type{Val{counter}} = Val{0}) where {counter} - typeof(x) -end -function anyeltypedual(x::Union{AbstractArray{T}, Set{T}}, - ::Type{Val{counter}} = Val{0}) where {counter} where { - T <: - Union{Number, - Symbol, - String}} - anyeltypedual(T) -end -function anyeltypedual(x::Union{AbstractArray{T}, Set{T}}, - ::Type{Val{counter}} = Val{0}) where {counter} where { - T <: Union{ - AbstractArray{ - <:Number, - }, - Set{ - <:Number, - }}} - anyeltypedual(eltype(x)) -end -function anyeltypedual(x::Union{AbstractArray{T}, Set{T}}, - ::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}} - anyeltypedual(eltype(x)) -end - -# Try to avoid this dispatch because it can lead to type inference issues when !isconcrete(eltype(x)) -function anyeltypedual(x::AbstractArray, ::Type{Val{counter}} = Val{0}) where {counter} - if isconcretetype(eltype(x)) - anyeltypedual(eltype(x)) - elseif !isempty(x) && all(i -> isassigned(x, i), 1:length(x)) && - counter < DUALCHECK_RECURSION_MAX - _counter = Val{counter + 1} - mapreduce(y -> anyeltypedual(y, _counter), promote_dual, x) - else - # This fallback to Any is required since otherwise we cannot handle `undef` in all cases - # misses cases of - Any - end -end - -function anyeltypedual(x::Set, ::Type{Val{counter}} = Val{0}) where {counter} - if isconcretetype(eltype(x)) - anyeltypedual(eltype(x)) - else - # This fallback to Any is required since otherwise we cannot handle `undef` in all cases - Any - end -end - -function anyeltypedual(x::Tuple, ::Type{Val{counter}} = Val{0}) where {counter} - # Handle the empty tuple case separately for inference and to avoid mapreduce error - if x === () - Any - else - diffeqmapreduce(anyeltypedual, promote_dual, x) - end -end -function anyeltypedual(x::AbstractDict, ::Type{Val{counter}} = Val{0}) where {counter} - isempty(x) ? eltype(values(x)) : mapreduce(anyeltypedual, promote_dual, values(x)) -end -function anyeltypedual(x::NamedTuple, ::Type{Val{counter}} = Val{0}) where {counter} - anyeltypedual(values(x)) -end - -function anyeltypedual( - f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter} - Any -end - -anyeltypedual(::@Kwargs{}, ::Type{Val{counter}} = Val{0}) where {counter} = Any -anyeltypedual(::Type{@Kwargs{}}, ::Type{Val{counter}} = Val{0}) where {counter} = Any - -# Opt out since these are using for preallocation, not differentiation -function anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module}, - ::Type{Val{counter}} = Val{0}) where {counter} - Any -end -function anyeltypedual(x::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where {T <: - ForwardDiff.AbstractConfig} - Any -end - -function anyeltypedual(x::ForwardDiff.DiffResults.DiffResult, - ::Type{Val{counter}} = Val{0}) where {counter} - Any -end -function anyeltypedual(x::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where {T <: - ForwardDiff.DiffResults.DiffResult} - Any -end - -function anyeltypedual(::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where {T <: ForwardDiff.Dual} - T -end - -function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, tspan, prob, kwargs) - if (haskey(kwargs, :callback) && has_continuous_callback(kwargs[:callback])) || - (haskey(prob.kwargs, :callback) && has_continuous_callback(prob.kwargs[:callback])) - return _promote_tspan(eltype(u0).(tspan), kwargs) +# Copy of the other prob2dtmin dispatch, just for optionality +function prob2dtmin(tspan, ::ForwardDiff.Dual, use_end_time) + t1, t2 = tspan + isfinite(t1) || throw(ArgumentError("t0 in the tspan `(t0, t1)` must be finite")) + if use_end_time && isfinite(t2 - t1) + return max(eps(t2), eps(t1)) else - return _promote_tspan(tspan, kwargs) + return max(eps(typeof(t1)), eps(t1)) end end -function promote_tspan(u0::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, p, tspan, prob, - kwargs) - return _promote_tspan(real(eltype(u0)).(tspan), kwargs) -end - -function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, - tspan::Tuple{<:ForwardDiff.Dual, <:ForwardDiff.Dual}, prob, kwargs) - return _promote_tspan(tspan, kwargs) -end - -value(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V -value(x::ForwardDiff.Dual) = value(ForwardDiff.value(x)) - -unitfulvalue(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V -unitfulvalue(x::ForwardDiff.Dual) = unitfulvalue(ForwardDiff.value(x)) - -sse(x::ForwardDiff.Dual) = sse(ForwardDiff.value(x)) + sum(sse, ForwardDiff.partials(x)) -function DiffEqBase.totallength(x::ForwardDiff.Dual) - return DiffEqBase.totallength(ForwardDiff.value(x)) + - sum(DiffEqBase.totallength, ForwardDiff.partials(x)) +function hasdualpromote(u0, t::Number) + hasmethod(ArrayInterface.promote_eltype, + Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) && + hasmethod(promote_rule, + Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) && + hasmethod(promote_rule, + Tuple{Type{eltype(u0)}, Type{typeof(t)}}) end @inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual, ::Any) = sqrt(sse(u)) diff --git a/ext/DiffEqBaseGTPSAExt.jl b/ext/DiffEqBaseGTPSAExt.jl index 10999bb26..130efde3a 100644 --- a/ext/DiffEqBaseGTPSAExt.jl +++ b/ext/DiffEqBaseGTPSAExt.jl @@ -1,7 +1,8 @@ module DiffEqBaseGTPSAExt using DiffEqBase -import DiffEqBase: value, ODE_DEFAULT_NORM +import DiffEqBase: ODE_DEFAULT_NORM +import SciMLBase: value, unitfulvalue using GTPSA value(x::TPS) = scalar(x) diff --git a/ext/DiffEqBaseMeasurementsExt.jl b/ext/DiffEqBaseMeasurementsExt.jl index a72423708..584e5242b 100644 --- a/ext/DiffEqBaseMeasurementsExt.jl +++ b/ext/DiffEqBaseMeasurementsExt.jl @@ -4,18 +4,6 @@ using DiffEqBase import DiffEqBase: value using Measurements -function DiffEqBase.promote_u0(u0::AbstractArray{<:Measurements.Measurement}, - p::AbstractArray{<:Measurements.Measurement}, t0) - u0 -end -DiffEqBase.promote_u0(u0, p::AbstractArray{<:Measurements.Measurement}, t0) = eltype(p).(u0) - -value(x::Type{Measurements.Measurement{T}}) where {T} = T -value(x::Measurements.Measurement) = Measurements.value(x) - -unitfulvalue(x::Type{Measurements.Measurement{T}}) where {T} = T -unitfulvalue(x::Measurements.Measurement) = Measurements.value(x) - # Support adaptive steps should be errorless @inline function DiffEqBase.ODE_DEFAULT_NORM( u::AbstractArray{ diff --git a/ext/DiffEqBaseMonteCarloMeasurementsExt.jl b/ext/DiffEqBaseMonteCarloMeasurementsExt.jl index d2b33d1a6..d3335a491 100644 --- a/ext/DiffEqBaseMonteCarloMeasurementsExt.jl +++ b/ext/DiffEqBaseMonteCarloMeasurementsExt.jl @@ -4,34 +4,6 @@ using DiffEqBase import DiffEqBase: value using MonteCarloMeasurements -function DiffEqBase.promote_u0( - u0::AbstractArray{ - <:MonteCarloMeasurements.AbstractParticles, - }, - p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, - t0) - u0 -end -function DiffEqBase.promote_u0(u0, - p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, - t0) - eltype(p).(u0) -end - -function DiffEqBase.promote_u0(::Nothing, - p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, - t0) - return nothing -end - -DiffEqBase.value(x::Type{MonteCarloMeasurements.AbstractParticles{T, N}}) where {T, N} = T -DiffEqBase.value(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles) -function DiffEqBase.unitfulvalue(x::Type{MonteCarloMeasurements.AbstractParticles{ - T, N}}) where {T, N} - T -end -DiffEqBase.unitfulvalue(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles) - # Support adaptive steps should be errorless @inline function DiffEqBase.ODE_DEFAULT_NORM( u::AbstractArray{ diff --git a/ext/DiffEqBaseMooncakeExt.jl b/ext/DiffEqBaseMooncakeExt.jl index 16e4b46e5..d0afcd110 100644 --- a/ext/DiffEqBaseMooncakeExt.jl +++ b/ext/DiffEqBaseMooncakeExt.jl @@ -29,18 +29,4 @@ import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive, }, true,) -@zero_adjoint MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any} -@is_primitive MinimalCtx Tuple{ - typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake), SciMLBase.ChainRulesOriginator -} - -@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = SciMLBase.MooncakeOriginator() - -function rrule!!( - f::CoDual{typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake)}, - X::CoDual{SciMLBase.ChainRulesOriginator} -) - return zero_fcodual(SciMLBase.MooncakeOriginator()), NoPullback(f, X) -end - end diff --git a/ext/DiffEqBaseReverseDiffExt.jl b/ext/DiffEqBaseReverseDiffExt.jl index 879989d30..a69c5715d 100644 --- a/ext/DiffEqBaseReverseDiffExt.jl +++ b/ext/DiffEqBaseReverseDiffExt.jl @@ -5,57 +5,6 @@ import DiffEqBase: value import ReverseDiff import DiffEqBase.ArrayInterface -function DiffEqBase.anyeltypedual(::Type{T}, - ::Type{Val{counter}} = Val{0}) where {counter} where { - V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}} - DiffEqBase.anyeltypedual(V, Val{counter}) -end - -DiffEqBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V -function DiffEqBase.value(x::Type{ - ReverseDiff.TrackedArray{V, D, N, VA, DA}, -}) where {V, D, - N, VA, - DA} - Array{V, N} -end -DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value -DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value - -DiffEqBase.unitfulvalue(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V -function DiffEqBase.unitfulvalue(x::Type{ - ReverseDiff.TrackedArray{V, D, N, VA, DA}, -}) where {V, D, - N, VA, - DA} - Array{V, N} -end -DiffEqBase.unitfulvalue(x::ReverseDiff.TrackedReal) = x.value -DiffEqBase.unitfulvalue(x::ReverseDiff.TrackedArray) = x.value - -# Force TrackedArray from TrackedReal when reshaping W\b -DiffEqBase._reshape(v::AbstractVector{<:ReverseDiff.TrackedReal}, siz) = reduce(vcat, v) - -DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0 -function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, - p::ReverseDiff.TrackedArray, t0) - u0 -end -function DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, - p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) - u0 -end -function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, - p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) - u0 -end -DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0) -function DiffEqBase.promote_u0( - u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ReverseDiff.ForwardDiff.Dual} - ReverseDiff.track(T.(u0)) -end -DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0) - # Support adaptive with non-tracked time @inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t) sqrt(sum(abs2, DiffEqBase.value(u)) / length(u)) diff --git a/ext/DiffEqBaseTrackerExt.jl b/ext/DiffEqBaseTrackerExt.jl index 72c869fb8..6342cb86b 100644 --- a/ext/DiffEqBaseTrackerExt.jl +++ b/ext/DiffEqBaseTrackerExt.jl @@ -4,36 +4,6 @@ using DiffEqBase import DiffEqBase: value import Tracker -DiffEqBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T -DiffEqBase.value(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array{T, N} -DiffEqBase.value(x::Tracker.TrackedReal) = x.data -DiffEqBase.value(x::Tracker.TrackedArray) = x.data - -DiffEqBase.unitfulvalue(x::Type{Tracker.TrackedReal{T}}) where {T} = T -function DiffEqBase.unitfulvalue(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} - Array{T, N} -end -DiffEqBase.unitfulvalue(x::Tracker.TrackedReal) = x.data -DiffEqBase.unitfulvalue(x::Tracker.TrackedArray) = x.data - -DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0 -function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, - p::Tracker.TrackedArray, t0) - u0 -end -function DiffEqBase.promote_u0(u0::Tracker.TrackedArray, - p::AbstractArray{<:Tracker.TrackedReal}, t0) - u0 -end -function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, - p::AbstractArray{<:Tracker.TrackedReal}, t0) - u0 -end -DiffEqBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0) -DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p).(u0) - -@inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x)) - # Support adaptive with non-tracked time @inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t) sqrt(sum(abs2, DiffEqBase.value(u)) / length(u)) diff --git a/ext/DiffEqBaseUnitfulExt.jl b/ext/DiffEqBaseUnitfulExt.jl index 381b0e516..85dd93f72 100644 --- a/ext/DiffEqBaseUnitfulExt.jl +++ b/ext/DiffEqBaseUnitfulExt.jl @@ -1,7 +1,7 @@ module DiffEqBaseUnitfulExt using DiffEqBase -import DiffEqBase: value +import SciMLBase: unitfulvalue, value using Unitful # Support adaptive errors should be errorless for exponentiation diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index a6cf2fa0c..1a0466ef3 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -94,7 +94,17 @@ using SciMLBase: @def, DEIntegrator, AbstractDEProblem, import SciMLBase: solve, init, step!, solve!, __init, __solve, update_coefficients!, update_coefficients, isadaptive, wrapfun_oop, wrapfun_iip, - unwrap_fw, promote_tspan, set_u!, set_t!, set_ut! + unwrap_fw, promote_tspan, set_u!, set_t!, set_ut!, + extract_alg, checkkwargs, has_kwargs, _concrete_solve_adjoint, _concrete_solve_forward, + eltypedual, get_updated_symbolic_problem, get_concrete_p, get_concrete_u0, promote_u0, + isconcreteu0, isconcretedu0, get_concrete_du0, _reshape, value, unitfulvalue, anyeltypedual, allowedkeywords, + sse, totallength, __sum, DualEltypeChecker, KeywordArgError, KeywordArgWarn, KeywordArgSilent, KWARGWARN_MESSAGE, KWARGERROR_MESSAGE, + CommonKwargError, IncompatibleInitialConditionError, NO_DEFAULT_ALGORITHM_MESSAGE, NoDefaultAlgorithmError, NO_TSPAN_MESSAGE, NoTspanError, + NAN_TSPAN_MESSAGE, NaNTspanError, NON_SOLVER_MESSAGE, NonSolverError, NOISE_SIZE_MESSAGE, NoiseSizeIncompatabilityError, PROBSOLVER_PAIRING_MESSAGE, + ProblemSolverPairingError, compatible_problem_types, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE, DirectAutodiffError, NONNUMBER_ELTYPE_MESSAGE, NonNumberEltypeError, + GENERIC_NUMBER_TYPE_ERROR_MESSAGE, GenericNumberTypeError, COMPLEX_SUPPORT_ERROR_MESSAGE, ComplexSupportError, COMPLEX_TSPAN_ERROR_MESSAGE, ComplexTspanError, + TUPLE_STATE_ERROR_MESSAGE, TupleStateError, MASS_MATRIX_ERROR_MESSAGE, IncompatibleMassMatrixError, LATE_BINDING_TSTOPS_ERROR_MESSAGE, LateBindingTstopsNotSupportedError, + NONCONCRETE_ELTYPE_MESSAGE, NonConcreteEltypeError, _vec import SciMLStructures @@ -107,10 +117,6 @@ import SymbolicIndexingInterface as SII ## Extension Functions -eltypedual(x) = false -promote_u0(::Nothing, p, t0) = nothing -isdualtype(::Type{T}) where {T} = false - ## Types """ @@ -167,6 +173,4 @@ export initialize!, finalize! export SensitivityADPassThrough -export KeywordArgError, KeywordArgWarn, KeywordArgSilent - end # module diff --git a/src/solve.jl b/src/solve.jl index 25eef6e11..512f1c689 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -7,539 +7,9 @@ NO_TSPAN_PROBS = Union{AbstractLinearProblem, AbstractNonlinearProblem, AbstractIntegralProblem, AbstractSteadyStateProblem, AbstractJumpProblem} -has_kwargs(_prob::AbstractDEProblem) = has_kwargs(typeof(_prob)) -Base.@pure __has_kwargs(::Type{T}) where {T} = :kwargs ∈ fieldnames(T) -has_kwargs(::Type{T}) where {T} = __has_kwargs(T) - -const allowedkeywords = (:dense, - :saveat, - :save_idxs, - :tstops, - :tspan, - :d_discontinuities, - :save_everystep, - :save_on, - :save_start, - :save_end, - :initialize_save, - :adaptive, - :abstol, - :reltol, - :dt, - :dtmax, - :dtmin, - :force_dtmin, - :internalnorm, - :controller, - :gamma, - :beta1, - :beta2, - :qmax, - :qmin, - :qsteady_min, - :qsteady_max, - :qoldinit, - :failfactor, - :calck, - :alias_u0, - :maxiters, - :maxtime, - :callback, - :isoutofdomain, - :unstable_check, - :verbose, - :merge_callbacks, - :progress, - :progress_steps, - :progress_name, - :progress_message, - :progress_id, - :timeseries_errors, - :dense_errors, - :weak_timeseries_errors, - :weak_dense_errors, - :wrap, - :calculate_error, - :initializealg, - :alg, - :save_noise, - :delta, - :seed, - :alg_hints, - :kwargshandle, - :trajectories, - :batch_size, - :sensealg, - :advance_to_tstop, - :stop_at_next_tstop, - :u0, - :p, - # These two are from the default algorithm handling - :default_set, - :second_time, - # This is for DiffEqDevTools - :prob_choice, - # Jump problems - :alias_jump, - # This is for copying/deepcopying noise in StochasticDiffEq - :alias_noise, - # This is for SimpleNonlinearSolve handling for batched Nonlinear Solves - :batch, - # Shooting method in BVP needs to differentiate between these two categories - :nlsolve_kwargs, - :odesolve_kwargs, - # If Solvers which internally use linsolve - :linsolve_kwargs, - # Solvers internally using EnsembleProblem - :ensemblealg, - # Fine Grained Control of Tracing (Storing and Logging) during Solve - :show_trace, - :trace_level, - :store_trace, - # Termination condition for solvers - :termination_condition, - # For AbstractAliasSpecifier - :alias, - # Parameter estimation with BVP - :fit_parameters) - -const KWARGWARN_MESSAGE = """ - Unrecognized keyword arguments found. - The only allowed keyword arguments to `solve` are: - $allowedkeywords - - See https://diffeq.sciml.ai/stable/basics/common_solver_opts/ for more details. - - Set kwargshandle=KeywordArgError for an error message. - Set kwargshandle=KeywordArgSilent to ignore this message. - """ - -const KWARGERROR_MESSAGE = """ - Unrecognized keyword arguments found. - The only allowed keyword arguments to `solve` are: - $allowedkeywords - - See https://diffeq.sciml.ai/stable/basics/common_solver_opts/ for more details. - """ - -struct CommonKwargError <: Exception - kwargs::Any -end - -function Base.showerror(io::IO, e::CommonKwargError) - println(io, KWARGERROR_MESSAGE) - notin = collect(map(x -> x ∉ allowedkeywords, keys(e.kwargs))) - unrecognized = collect(keys(e.kwargs))[notin] - print(io, "Unrecognized keyword arguments: ") - printstyled(io, unrecognized; bold = true, color = :red) - print(io, "\n\n") - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -@enum KeywordArgError KeywordArgWarn KeywordArgSilent - -const INCOMPATIBLE_U0_MESSAGE = """ - Initial condition incompatible with functional form. - Detected an in-place function with an initial condition of type Number or SArray. - This is incompatible because Numbers cannot be mutated, i.e. - `x = 2.0; y = 2.0; x .= y` will error. - - If using a immutable initial condition type, please use the out-of-place form. - I.e. define the function `du=f(u,p,t)` instead of attempting to "mutate" the immutable `du`. - - If your differential equation function was defined with multiple dispatches and one is - in-place, then the automatic detection will choose in-place. In this case, override the - choice in the problem constructor, i.e. `ODEProblem{false}(f,u0,tspan,p,kwargs...)`. - - For a longer discussion on mutability vs immutability and in-place vs out-of-place, see: - https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Example-Accelerating-a-Non-Stiff-Equation:-The-Lorenz-Equation - """ - -struct IncompatibleInitialConditionError <: Exception end - -function Base.showerror(io::IO, e::IncompatibleInitialConditionError) - print(io, INCOMPATIBLE_U0_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NO_DEFAULT_ALGORITHM_MESSAGE = """ - Default algorithm choices require DifferentialEquations.jl. - Please specify an algorithm (e.g., `solve(prob, Tsit5())` or - `init(prob, Tsit5())` for an ODE) or import DifferentialEquations - directly. - - You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ - and its associated pages. - """ - -struct NoDefaultAlgorithmError <: Exception end - -function Base.showerror(io::IO, e::NoDefaultAlgorithmError) - print(io, NO_DEFAULT_ALGORITHM_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NO_TSPAN_MESSAGE = """ - No tspan is set in the problem or chosen in the init/solve call - """ - -struct NoTspanError <: Exception end - -function Base.showerror(io::IO, e::NoTspanError) - print(io, NO_TSPAN_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NAN_TSPAN_MESSAGE = """ - NaN tspan is set in the problem or chosen in the init/solve call. - Note that -Inf and Inf values are allowed in the timespan for solves - which are terminated via callbacks, however NaN values are not allowed - since the direction of time is undetermined. - """ - -struct NaNTspanError <: Exception end - -function Base.showerror(io::IO, e::NaNTspanError) - print(io, NAN_TSPAN_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NON_SOLVER_MESSAGE = """ - The arguments to solve are incorrect. - The second argument must be a solver choice, `solve(prob,alg)` - where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`. - - Please double check the arguments being sent to the solver. - - You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ - and its associated pages. - """ - -struct NonSolverError <: Exception end - -function Base.showerror(io::IO, e::NonSolverError) - print(io, NON_SOLVER_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NOISE_SIZE_MESSAGE = """ - Noise sizes are incompatible. The expected number of noise terms in the defined - `noise_rate_prototype` does not match the number of noise terms in the defined - `AbstractNoiseProcess`. Please ensure that - size(prob.noise_rate_prototype,2) == length(prob.noise.W[1]). - - Note: Noise process definitions require that users specify `u0`, and this value is - directly used in the definition. For example, if `noise = WienerProcess(0.0,0.0)`, - then the noise process is a scalar with `u0=0.0`. If `noise = WienerProcess(0.0,[0.0])`, - then the noise process is a vector with `u0=0.0`. If `noise_rate_prototype = zeros(2,4)`, - then the noise process must be a 4-dimensional process, for example - `noise = WienerProcess(0.0,zeros(4))`. This error is a sign that the user definition - of `noise_rate_prototype` and `noise` are not aligned in this manner and the definitions should - be double checked. - """ - -struct NoiseSizeIncompatabilityError <: Exception - prototypesize::Int - noisesize::Int -end - -function Base.showerror(io::IO, e::NoiseSizeIncompatabilityError) - println(io, NOISE_SIZE_MESSAGE) - println(io, "size(prob.noise_rate_prototype,2) = $(e.prototypesize)") - println(io, "length(prob.noise.W[1]) = $(e.noisesize)") - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const PROBSOLVER_PAIRING_MESSAGE = """ - Incompatible problem+solver pairing. - For example, this can occur if an ODE solver is passed with an SDEProblem. - Solvers are only capable of handling specific problem types. Please double - check that the chosen pairing is capable for handling the given problems. - """ - -struct ProblemSolverPairingError <: Exception - prob::Any - alg::Any -end - -function Base.showerror(io::IO, e::ProblemSolverPairingError) - println(io, PROBSOLVER_PAIRING_MESSAGE) - println(io, "Problem type: $(SciMLBase.__parameterless_type(typeof(e.prob)))") - println(io, "Solver type: $(SciMLBase.__parameterless_type(typeof(e.alg)))") - println(io, - "Problem types compatible with the chosen solver: $(compatible_problem_types(e.prob,e.alg))") - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -function compatible_problem_types(prob, alg) - if alg isa AbstractODEAlgorithm - ODEProblem - elseif alg isa AbstractSDEAlgorithm - (SDEProblem, SDDEProblem) - elseif alg isa AbstractDDEAlgorithm # StochasticDelayDiffEq.jl just uses the SDE alg - DDEProblem - elseif alg isa AbstractDAEAlgorithm - DAEProblem - elseif alg isa AbstractSteadyStateAlgorithm - SteadyStateProblem - end -end - -const DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE = """ - Incompatible solver + automatic differentiation pairing. - The chosen automatic differentiation algorithm requires the ability - for compiler transforms on the code which is only possible on pure-Julia - solvers such as those from OrdinaryDiffEq.jl. Direct differentiation methods - which require this ability include: - - - Direct use of ForwardDiff.jl on the solver - - `ForwardDiffSensitivity`, `ReverseDiffAdjoint`, `TrackerAdjoint`, and `ZygoteAdjoint` - sensealg choices for adjoint differentiation. - - Either switch the choice of solver to a pure Julia method, or change the automatic - differentiation method to one that does not require such transformations. - - For more details on automatic differentiation, adjoint, and sensitivity analysis - of differential equations, see the documentation page: - - https://diffeq.sciml.ai/stable/analysis/sensitivity/ - """ - -struct DirectAutodiffError <: Exception end - -function Base.showerror(io::IO, e::DirectAutodiffError) - println(io, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NONCONCRETE_ELTYPE_MESSAGE = """ - Non-concrete element type inside of an `Array` detected. - Arrays with non-concrete element types, such as - `Array{Union{Float32,Float64}}`, are not supported by the - differential equation solvers. Anyways, this is bad for - performance so you don't want to be doing this! - - If this was a mistake, promote the element types to be - all the same. If this was intentional, for example, - using Unitful.jl with different unit values, then use - an array type which has fast broadcast support for - heterogeneous values such as the ArrayPartition - from RecursiveArrayTools.jl. For example: - - ```julia - using RecursiveArrayTools - x = ArrayPartition([1.0,2.0],[1f0,2f0]) - y = ArrayPartition([3.0,4.0],[3f0,4f0]) - x .+ y # fast, stable, and usable as u0 into DiffEq! - ``` - - Element type: - """ - -struct NonConcreteEltypeError <: Exception - eltype::Any -end - -function Base.showerror(io::IO, e::NonConcreteEltypeError) - print(io, NONCONCRETE_ELTYPE_MESSAGE) - print(io, e.eltype) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const NONNUMBER_ELTYPE_MESSAGE = """ - Non-Number element type inside of an `Array` detected. - Arrays with non-number element types, such as - `Array{Array{Float64}}`, are not supported by the - solvers. - - If you are trying to use an array of arrays structure, - look at the tools in RecursiveArrayTools.jl. For example: - - If this was a mistake, promote the element types to be - all the same. If this was intentional, for example, - using Unitful.jl with different unit values, then use - an array type which has fast broadcast support for - heterogeneous values such as the ArrayPartition - from RecursiveArrayTools.jl. For example: - - ```julia - using RecursiveArrayTools - u0 = ArrayPartition([1.0,2.0],[3.0,4.0]) - u0 = VectorOfArray([1.0,2.0],[3.0,4.0]) - ``` - - are both initial conditions which would be compatible with - the solvers. Or use ComponentArrays.jl for more complex - nested structures. - - Element type: - """ - -struct NonNumberEltypeError <: Exception - eltype::Any -end - -function Base.showerror(io::IO, e::NonNumberEltypeError) - print(io, NONNUMBER_ELTYPE_MESSAGE) - print(io, e.eltype) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const GENERIC_NUMBER_TYPE_ERROR_MESSAGE = """ - Non-standard number type (i.e. not Float32, Float64, - ComplexF32, or ComplexF64) detected as the element type - for the initial condition or time span. These generic - number types are only compatible with the pure Julia - solvers which support generic programming, such as - OrdinaryDiffEq.jl. The chosen solver does not support - this functionality. Please double check that the initial - condition and time span types are correct, and check that - the chosen solver was correct. - """ - -struct GenericNumberTypeError <: Exception - alg::Any - uType::Any - tType::Any -end - -function Base.showerror(io::IO, e::GenericNumberTypeError) - println(io, GENERIC_NUMBER_TYPE_ERROR_MESSAGE) - println(io, "Solver: $(e.alg)") - println(io, "u0 type: $(e.uType)") - print(io, "Timespan type: $(e.tType)") - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const COMPLEX_SUPPORT_ERROR_MESSAGE = """ - Complex number type (i.e. ComplexF32, or ComplexF64) - detected as the element type for the initial condition - with an algorithm that does not support complex numbers. - Please check that the initial condition type is correct. - If complex number support is needed, try different solvers - such as those from OrdinaryDiffEq.jl. - """ - -struct ComplexSupportError <: Exception - alg::Any -end - -function Base.showerror(io::IO, e::ComplexSupportError) - println(io, COMPLEX_SUPPORT_ERROR_MESSAGE) - println(io, "Solver: $(e.alg)") - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const COMPLEX_TSPAN_ERROR_MESSAGE = """ - Complex number type (i.e. ComplexF32, or ComplexF64) - detected as the element type for the independent variable - (i.e. time span). Please check that the tspan type is correct. - No solvers support complex time spans. If this is required, - please open an issue. - """ - -struct ComplexTspanError <: Exception end - -function Base.showerror(io::IO, e::ComplexTspanError) - println(io, COMPLEX_TSPAN_ERROR_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const TUPLE_STATE_ERROR_MESSAGE = """ - Tuple type used as a state. Since a tuple does not have vector - properties, it will not work as a state type in equation solvers. - Instead, change your equation from using tuple constructors `()` - to static array constructors `SA[]`. For example, change: - - ```julia - function ftup((a,b),p,t) - return b,-a - end - u0 = (1.0,2.0) - tspan = (0.0,1.0) - ODEProblem(ftup,u0,tspan) - ``` - - to: - - ```julia - using StaticArrays - function fsa(u,p,t) - SA[u[2],u[1]] - end - u0 = SA[1.0,2.0] - tspan = (0.0,1.0) - ODEProblem(ftup,u0,tspan) - ``` - - This will be safer and fast for small ODEs. For more information, see: - https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Further-Optimizations-of-Small-Non-Stiff-ODEs-with-StaticArrays - """ - -struct TupleStateError <: Exception end - -function Base.showerror(io::IO, e::TupleStateError) - println(io, TUPLE_STATE_ERROR_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const MASS_MATRIX_ERROR_MESSAGE = """ - Mass matrix size is incompatible with initial condition - sizing. The mass matrix must represent the `vec` - form of the initial condition `u0`, i.e. - `size(mm,1) == size(mm,2) == length(u)` - """ - -struct IncompatibleMassMatrixError <: Exception - sz::Int - len::Int -end - -function Base.showerror(io::IO, e::IncompatibleMassMatrixError) - println(io, MASS_MATRIX_ERROR_MESSAGE) - print(io, "size(prob.f.mass_matrix,1): ") - println(io, e.sz) - print(io, "length(u0): ") - println(e.len) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -const LATE_BINDING_TSTOPS_ERROR_MESSAGE = """ - This solver does not support providing `tstops` as a function. - Consider using a different solver or providing `tstops` as an array - of times. - """ - -struct LateBindingTstopsNotSupportedError <: Exception end - -function Base.showerror(io::IO, e::LateBindingTstopsNotSupportedError) - println(io, LATE_BINDING_TSTOPS_ERROR_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -""" - $(TYPEDSIGNATURES) - -Given the index provider `indp` used to construct the problem `prob` being solved, return -an updated `prob` to be used for solving. All implementations should accept arbitrary -keyword arguments. - -Should be called before the problem is solved, after performing type-promotion on the -problem. If the returned problem is not `===` the provided `prob`, it is assumed to -contain the `u0` and `p` passed as keyword arguments. - -# Keyword Arguments - -- `u0`, `p`: Override values for `state_values(prob)` and `parameter_values(prob)` which - should be used instead of the ones in `prob`. -""" -function get_updated_symbolic_problem(indp, prob; kw...) - return prob -end - function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing, kwargs...) - kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle + kwargshandle = kwargshandle === nothing ? SciMLBase.KeywordArgError : kwargshandle kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? _prob.kwargs[:kwargshandle] : kwargshandle @@ -613,7 +83,7 @@ end function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing, kwargs...) - kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle + kwargshandle = kwargshandle === nothing ? SciMLBase.KeywordArgError : kwargshandle kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? _prob.kwargs[:kwargshandle] : kwargshandle @@ -719,7 +189,7 @@ function build_null_solution(prob::AbstractDEProblem, args...; save_everystep = true, save_on = true, save_start = save_everystep || isempty(saveat) || - saveat isa Number || prob.tspan[1] in saveat, + saveat isa Number || prob.tspan[1] in saveat, save_end = true, kwargs...) ts = if saveat === () @@ -752,7 +222,7 @@ function build_null_solution( save_everystep = true, save_on = true, save_start = save_everystep || isempty(saveat) || - saveat isa Number || prob.tspan[1] in saveat, + saveat isa Number || prob.tspan[1] in saveat, save_end = true, kwargs...) prob, success = hack_null_solution_init(prob) @@ -1229,22 +699,6 @@ function solve(prob::AbstractJumpProblem, args...; kwargs...) __solve(prob, args...; kwargs...) end -function checkkwargs(kwargshandle; kwargs...) - if any(x -> x ∉ allowedkeywords, keys(kwargs)) - if kwargshandle == KeywordArgError - throw(CommonKwargError(kwargs)) - elseif kwargshandle == KeywordArgWarn - @warn KWARGWARN_MESSAGE - unrecognized = setdiff(keys(kwargs), allowedkeywords) - print("Unrecognized keyword arguments: ") - printstyled(unrecognized; bold = true, color = :red) - print("\n\n") - else - @assert kwargshandle == KeywordArgSilent - end - end -end - function get_concrete_problem(prob::AbstractJumpProblem, isadapt; kwargs...) get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...) end @@ -1430,100 +884,6 @@ function get_concrete_tspan(prob, isadapt, kwargs, p) tspan end -function isconcreteu0(prob, t0, kwargs) - !eval_u0(prob.u0) && prob.u0 !== nothing && !isdistribution(prob.u0) -end - -function isconcretedu0(prob, t0, kwargs) - !eval_u0(prob.u0) && prob.du0 !== nothing && !isdistribution(prob.du0) -end - -function get_concrete_u0(prob, isadapt, t0, kwargs) - if eval_u0(prob.u0) - u0 = prob.u0(prob.p, t0) - elseif haskey(kwargs, :u0) - u0 = kwargs[:u0] - else - u0 = prob.u0 - end - - isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) - - _u0 = handle_distribution_u0(u0) - - if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) - throw(IncompatibleInitialConditionError()) - end - - nu0 = length(something(_u0, ())) - if isdefined(prob.f, :mass_matrix) && prob.f.mass_matrix !== nothing && - prob.f.mass_matrix isa AbstractArray && - size(prob.f.mass_matrix, 1) !== nu0 - throw(IncompatibleMassMatrixError(size(prob.f.mass_matrix, 1), nu0)) - end - - if _u0 isa Tuple - throw(TupleStateError()) - end - - _u0 -end - -function get_concrete_u0(prob::BVProblem, isadapt, t0, kwargs) - if haskey(kwargs, :u0) - u0 = kwargs[:u0] - else - u0 = prob.u0 - end - - isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) - - _u0 = handle_distribution_u0(u0) - - if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) - throw(IncompatibleInitialConditionError()) - end - - if _u0 isa Tuple - throw(TupleStateError()) - end - - return _u0 -end - -function get_concrete_du0(prob, isadapt, t0, kwargs) - if eval_u0(prob.du0) - du0 = prob.du0(prob.p, t0) - elseif haskey(kwargs, :du0) - du0 = kwargs[:du0] - else - du0 = prob.du0 - end - - isadapt && eltype(du0) <: Integer && (du0 = float.(du0)) - - _du0 = handle_distribution_u0(du0) - - if isinplace(prob) && (_du0 isa Number || _du0 isa SArray) - throw(IncompatibleInitialConditionError()) - end - - _du0 -end - -function get_concrete_p(prob, kwargs) - if haskey(kwargs, :p) - p = kwargs[:p] - else - p = prob.p - end -end - -handle_distribution_u0(_u0) = _u0 - -eval_u0(u0::Function) = true -eval_u0(u0) = false - function __solve( prob::AbstractDEProblem, args...; default_set = false, second_time = false, kwargs...) @@ -1615,23 +975,6 @@ function check_prob_alg_pairing(prob, alg) end end -@inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) - if isempty(solve_args) || isnothing(first(solve_args)) - if haskey(solve_kwargs, :alg) - solve_kwargs[:alg] - elseif haskey(prob_kwargs, :alg) - prob_kwargs[:alg] - else - nothing - end - elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm && - !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) - first(solve_args) - else - nothing - end -end - ################### Differentiation """ @@ -1714,40 +1057,3 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator; kwargs...) end end - -#### -# Catch undefined AD overload cases - -const ADJOINT_NOT_FOUND_MESSAGE = """ - Compatibility with reverse-mode automatic differentiation requires SciMLSensitivity.jl. - Please install SciMLSensitivity.jl and do `using SciMLSensitivity`/`import SciMLSensitivity` - for this functionality. For more details, see https://sensitivity.sciml.ai/dev/. - """ - -struct AdjointNotFoundError <: Exception end - -function Base.showerror(io::IO, e::AdjointNotFoundError) - print(io, ADJOINT_NOT_FOUND_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -function _concrete_solve_adjoint(args...; kwargs...) - throw(AdjointNotFoundError()) -end - -const FORWARD_SENSITIVITY_NOT_FOUND_MESSAGE = """ - Compatibility with forward-mode automatic differentiation requires SciMLSensitivity.jl. - Please install SciMLSensitivity.jl and do `using SciMLSensitivity`/`import SciMLSensitivity` - for this functionality. For more details, see https://sensitivity.sciml.ai/dev/. - """ - -struct ForwardSensitivityNotFoundError <: Exception end - -function Base.showerror(io::IO, e::ForwardSensitivityNotFoundError) - print(io, FORWARD_SENSITIVITY_NOT_FOUND_MESSAGE) - println(io, TruncatedStacktraces.VERBOSE_MSG) -end - -function _concrete_solve_forward(args...; kwargs...) - throw(ForwardSensitivityNotFoundError()) -end diff --git a/src/utils.jl b/src/utils.jl index d5c316348..15efcf1fe 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,28 +1,3 @@ -# Handled in Extensions -value(x) = x -unitfulvalue(x) = x -isdistribution(u0) = false -sse(x::Number) = abs2(x) - -# Static Arrays don't support the `init` keyword argument for `sum` -@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...) -@inline function __sum( - f::F, a::StaticArraysCore.StaticArray...; init, kwargs...) where {F} - return mapreduce(f, +, a...; init, kwargs...) -end - -totallength(x::Number) = 1 -totallength(x::AbstractArray) = __sum(totallength, x; init = 0) - -_vec(v) = vec(v) -_vec(v::Number) = v -_vec(v::AbstractSciMLScalarOperator) = v -_vec(v::AbstractVector) = v - -_reshape(v, siz) = reshape(v, siz) -_reshape(v::Number, siz) = v -_reshape(v::AbstractSciMLScalarOperator, siz) = v - macro tight_loop_macros(ex) :($(esc(ex))) end @@ -129,64 +104,3 @@ end @inline __add_and_norm(::typeof(Base.Fix2(norm, Inf)), x, y) = __maximum_abs(+, x, y) @inline __add_and_norm(f::F, x, y) where {F} = __norm_op(f, +, x, y) -struct DualEltypeChecker{T, T2} - x::T - counter::T2 -end - -anyeltypedual(x) = anyeltypedual(x, Val{0}) -anyeltypedual(x, counter) = Any - -function promote_u0(u0, p, t0) - if SciMLStructures.isscimlstructure(p) - _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] - if !isequal(_p, p) - return promote_u0(u0, _p, t0) - end - end - Tu = eltype(u0) - if isdualtype(Tu) - return u0 - end - Tp = anyeltypedual(p, Val{0}) - if Tp == Any - Tp = Tu - end - Tt = anyeltypedual(t0, Val{0}) - if Tt == Any - Tt = Tu - end - Tcommon = promote_type(Tu, Tp, Tt) - return if isdualtype(Tcommon) - Tcommon.(u0) - else - u0 - end -end - -function promote_u0(u0::AbstractArray{<:Complex}, p, t0) - if SciMLStructures.isscimlstructure(p) - _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] - if !isequal(_p, p) - return promote_u0(u0, _p, t0) - end - end - Tu = real(eltype(u0)) - if isdualtype(Tu) - return u0 - end - Tp = anyeltypedual(p, Val{0}) - if Tp == Any - Tp = Tu - end - Tt = anyeltypedual(t0, Val{0}) - if Tt == Any - Tt = Tu - end - Tcommon = promote_type(eltype(u0), Tp, Tt) - return if isdualtype(real(Tcommon)) - Tcommon.(u0) - else - u0 - end -end diff --git a/test/downstream/kwarg_warn.jl b/test/downstream/kwarg_warn.jl index 32e77976a..879971dd9 100644 --- a/test/downstream/kwarg_warn.jl +++ b/test/downstream/kwarg_warn.jl @@ -1,4 +1,5 @@ using OrdinaryDiffEq, Test +using DiffEqBase function lorenz(du, u, p, t) du[1] = 10.0(u[2] - u[1]) du[2] = u[1] * (28.0 - u[3]) - u[2] @@ -8,10 +9,10 @@ u0 = [1.0; 0.0; 0.0] tspan = (0.0, 100.0) prob = ODEProblem(lorenz, u0, tspan) @test_nowarn sol = solve(prob, Tsit5(), reltol = 1e-6) -sol = solve(prob, Tsit5(), rel_tol = 1e-6, kwargshandle = DiffEqBase.KeywordArgWarn) -@test_logs (:warn, DiffEqBase.KWARGWARN_MESSAGE) sol=solve( - prob, Tsit5(), rel_tol = 1e-6, kwargshandle = DiffEqBase.KeywordArgWarn) -@test_throws DiffEqBase.CommonKwargError sol=solve(prob, Tsit5(), rel_tol = 1e-6) +sol = solve(prob, Tsit5(), rel_tol = 1e-6, kwargshandle = SciMLBase.KeywordArgWarn) +@test_logs (:warn, SciMLBase.KWARGWARN_MESSAGE) sol=solve( + prob, Tsit5(), rel_tol = 1e-6, kwargshandle = SciMLBase.KeywordArgWarn) +@test_throws SciMLBase.CommonKwargError sol=solve(prob, Tsit5(), rel_tol = 1e-6) -prob = ODEProblem(lorenz, u0, tspan, test = 2.0, kwargshandle = DiffEqBase.KeywordArgWarn) -@test_logs (:warn, DiffEqBase.KWARGWARN_MESSAGE) sol=solve(prob, Tsit5(), reltol = 1e-6) +prob = ODEProblem(lorenz, u0, tspan, test = 2.0, kwargshandle = SciMLBase.KeywordArgWarn) +@test_logs (:warn, SciMLBase.KWARGWARN_MESSAGE) sol=solve(prob, Tsit5(), reltol = 1e-6) diff --git a/test/downstream/solve_error_handling.jl b/test/downstream/solve_error_handling.jl index dae035e14..eab1fadf4 100644 --- a/test/downstream/solve_error_handling.jl +++ b/test/downstream/solve_error_handling.jl @@ -10,7 +10,7 @@ function f(du, u, p, t) du .= 2.0 * u end prob = ODEProblem(f, u0, tspan) -@test_throws DiffEqBase.IncompatibleInitialConditionError sol=solve(prob, Tsit5()) +@test_throws SciMLBase.IncompatibleInitialConditionError sol=solve(prob, Tsit5()) prob = ODEProblem{false}(f, u0, tspan) sol = solve(prob, Tsit5()) @@ -18,40 +18,40 @@ sol = solve(prob, nothing, alg = Tsit5()) sol = init(prob, nothing, alg = Tsit5()) prob = ODEProblem{false}(f, 1.0 + im, tspan) -@test_throws DiffEqBase.ComplexSupportError solve(prob, CVODE_Adams()) +@test_throws SciMLBase.ComplexSupportError solve(prob, CVODE_Adams()) -@test_throws DiffEqBase.ProblemSolverPairingError solve(prob, DFBDF()) -@test_throws DiffEqBase.NonSolverError solve(prob, 5.0) +@test_throws SciMLBase.ProblemSolverPairingError solve(prob, DFBDF()) +@test_throws SciMLBase.NonSolverError solve(prob, 5.0) prob = ODEProblem{false}(f, u0, (nothing, nothing)) -@test_throws DiffEqBase.NoTspanError solve(prob, Tsit5()) +@test_throws SciMLBase.NoTspanError solve(prob, Tsit5()) prob = ODEProblem{false}(f, u0, (NaN, 1.0)) -@test_throws DiffEqBase.NaNTspanError solve(prob, Tsit5()) +@test_throws SciMLBase.NaNTspanError solve(prob, Tsit5()) prob = ODEProblem{false}(f, u0, (1.0, NaN)) -@test_throws DiffEqBase.NaNTspanError solve(prob, Tsit5()) +@test_throws SciMLBase.NaNTspanError solve(prob, Tsit5()) prob = ODEProblem{false}(f, Any[1.0, 1.0f0], tspan) -@test_throws DiffEqBase.NonConcreteEltypeError solve(prob, Tsit5()) +@test_throws SciMLBase.NonConcreteEltypeError solve(prob, Tsit5()) prob = ODEProblem{false}(f, (1.0, 1.0f0), tspan) -@test_throws DiffEqBase.TupleStateError solve(prob, Tsit5()) +@test_throws SciMLBase.TupleStateError solve(prob, Tsit5()) prob = ODEProblem{false}(f, u0, (0.0 + im, 1.0)) -@test_throws DiffEqBase.ComplexTspanError solve(prob, Tsit5()) +@test_throws SciMLBase.ComplexTspanError solve(prob, Tsit5()) for u0 in ([0.0, 0.0], nothing) fmm = ODEFunction(f, mass_matrix = zeros(3, 3)) prob = ODEProblem(fmm, u0, (0.0, 1.0)) - @test_throws DiffEqBase.IncompatibleMassMatrixError solve(prob, Tsit5()) + @test_throws SciMLBase.IncompatibleMassMatrixError solve(prob, Tsit5()) end # Allow empty mass matrix for empty u0 fmm = ODEFunction((du, u, t) -> nothing, mass_matrix = zeros(0, 0)) prob = ODEProblem(fmm, nothing, (0.0, 1.0)) sol = solve(prob, Tsit5()) -@test isa(sol, DiffEqBase.ODESolution) +@test isa(sol, SciMLBase.ODESolution) f(du, u, p, t) = du .= 1.01u function g(du, u, p, t) @@ -71,7 +71,7 @@ prob = SDEProblem(f, (0.0, 1.0), noise_rate_prototype = complex(zeros(2, 4)), noise = StochasticDiffEq.RealWienerProcess(0.0, zeros(3))) -@test_throws DiffEqBase.NoiseSizeIncompatabilityError solve(prob, LambaEM()) +@test_throws SciMLBase.NoiseSizeIncompatabilityError solve(prob, LambaEM()) function g!(du, u, p, t) du[1] .= u[1] + ones(3, 3) @@ -79,4 +79,4 @@ function g!(du, u, p, t) end u0 = [zeros(3, 3), zeros(3, 3)] prob = ODEProblem(g!, u0, (0, 1.0)) -@test_throws DiffEqBase.NonNumberEltypeError solve(prob, Tsit5()) +@test_throws SciMLBase.NonNumberEltypeError solve(prob, Tsit5()) diff --git a/test/downstream/unitful.jl b/test/downstream/unitful.jl index 89bf38836..68caf0a68 100644 --- a/test/downstream/unitful.jl +++ b/test/downstream/unitful.jl @@ -4,6 +4,6 @@ prob = ODEProblem(f, [2.0u"m"], (0.0u"s", Inf * u"s")) intg = init(prob, Tsit5()) @test_nowarn step!(intg, 0.02u"s", true) -@test DiffEqBase.unitfulvalue(u"1/s") == u"1/s" -@test DiffEqBase.value(ForwardDiff.Dual(1) * u"1/s") == 1 -@test DiffEqBase.unitfulvalue(ForwardDiff.Dual(1) * u"1/s") == u"1/s" +@test SciMLBase.unitfulvalue(u"1/s") == u"1/s" +@test SciMLBase.value(ForwardDiff.Dual(1) * u"1/s") == 1 +@test SciMLBase.unitfulvalue(ForwardDiff.Dual(1) * u"1/s") == u"1/s"