diff --git a/Project.toml b/Project.toml index c27f3a7cf..c803a174b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.111.1" +version = "2.112.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -36,22 +36,35 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" [weakdeps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" +Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" +MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" RCall = "6f49c342-dc21-5d91-9882-a32aef131414" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" +SciMLBaseDistributionsExt = "Distributions" +SciMLBaseForwardDiffExt = "ForwardDiff" SciMLBaseMLStyleExt = "MLStyle" SciMLBaseMakieExt = "Makie" +SciMLBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements" +SciMLBaseMooncakeExt = "Mooncake" SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" SciMLBasePythonCallExt = "PythonCall" SciMLBaseRCallExt = "RCall" +SciMLBaseReverseDiffExt = "ReverseDiff" +SciMLBaseTrackerExt = "Tracker" SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"] [compat] @@ -64,6 +77,7 @@ ChainRulesCore = "1.18" CommonSolve = "0.2.4" ConstructionBase = "1.5" Distributed = "1.10" +Distributions = "0.25" DocStringExtensions = "0.9" EnumX = "1" ForwardDiff = "0.10.36, 1" @@ -74,6 +88,9 @@ Logging = "1.10" MLStyle = "0.4.17" Makie = "0.20, 0.21, 0.22, 0.23, 0.24" Markdown = "1.10" +Measurements = "2" +MonteCarloMeasurements = "1" +Mooncake = "0.4" Moshi = "0.3" PartialFunctions = "1.1" PreallocationTools = "0.4.31" @@ -86,6 +103,7 @@ RCall = "0.14.0" RecipesBase = "1.3.4" RecursiveArrayTools = "3.35" Reexport = "1" +ReverseDiff = "1" RuntimeGeneratedFunctions = "0.5.12" SciMLOperators = "1.3" SciMLStructures = "1.1" @@ -95,6 +113,7 @@ StaticArraysCore = "1.4" Statistics = "1.10" SymbolicIndexingInterface = "0.3.36" Tables = "1.11" +Tracker = "0.2" Zygote = "0.7.10" julia = "1.10" diff --git a/ext/SciMLBaseDistributionsExt.jl b/ext/SciMLBaseDistributionsExt.jl new file mode 100644 index 000000000..d6a38e8bc --- /dev/null +++ b/ext/SciMLBaseDistributionsExt.jl @@ -0,0 +1,8 @@ +module SciMLBaseDistributionsExt + +using Distributions, SciMLBase + +SciMLBase.handle_distribution_u0(_u0::Distributions.Sampleable) = rand(_u0) +SciMLBase.isdistribution(_u0::Distributions.Sampleable) = true + +end \ No newline at end of file diff --git a/ext/SciMLBaseForwardDiffExt.jl b/ext/SciMLBaseForwardDiffExt.jl new file mode 100644 index 000000000..dfbbc438e --- /dev/null +++ b/ext/SciMLBaseForwardDiffExt.jl @@ -0,0 +1,410 @@ +module SciMLBaseForwardDiffExt + +using SciMLBase, ForwardDiff +using ArrayInterface + +import SciMLBase: + wrapfun_oop, wrapfun_iip, isdualtype, value, DualEltypeChecker, + AbstractTimeseriesSolution, NonlinearProblem, NonlinearLeastSquaresProblem, + ODEProblem, SDEProblem, RODEProblem, DDEProblem, PDEProblem, DAEProblem, + RecursiveArrayTools, totallength, sse, anyeltypedual + +eltypedual(x) = eltype(x) <: ForwardDiff.Dual +isdualtype(::Type{<:ForwardDiff.Dual}) = true + +promote_dual(::Type{T}, ::Type{T2}) where {T <: ForwardDiff.Dual, T2} = T +function promote_dual(::Type{T}, + ::Type{T2}) where {T <: ForwardDiff.Dual, T2 <: ForwardDiff.Dual} + T +end +promote_dual(::Type{T}, ::Type{T2}) where {T, T2 <: ForwardDiff.Dual} = T2 + +function promote_dual(::Type{T}, + ::Type{T2}) where {T3, T4, V, V2 <: ForwardDiff.Dual, N, N2, + T <: ForwardDiff.Dual{T3, V, N}, + T2 <: ForwardDiff.Dual{T4, V2, N2}} + T2 +end +function promote_dual(::Type{T}, + ::Type{T2}) where {T3, T4, V <: ForwardDiff.Dual, V2, N, N2, + T <: ForwardDiff.Dual{T3, V, N}, + T2 <: ForwardDiff.Dual{T4, V2, N2}} + T +end +function promote_dual(::Type{T}, + ::Type{T2}) where { + T3, V <: ForwardDiff.Dual, V2 <: ForwardDiff.Dual, + N, + T <: ForwardDiff.Dual{T3, V, N}, + T2 <: ForwardDiff.Dual{T3, V2, N}} + ForwardDiff.Dual{T3, promote_dual(V, V2), N} +end + +""" + promote_dual(::Type{T},::Type{T2}) + +Is like the number promotion system, but always prefers a dual number type above +anything else. For higher order differentiation, it returns the most dualiest of +them all. This is then used to promote `u0` into the suspected highest differentiation +space for solving the equation. +""" +promote_dual(::Type{T}, ::Type{T2}) where {T, T2} = T + +# `reduce` and `map` are specialized on tuples to be unrolled (via recursion) +# Therefore, they can be type stable even with heterogeneous input types. +# We also don't care about allocating any temporaries with them, as it should +# all be unrolled and optimized away. +# Being unrolled also means const prop can work for things like +# `mapreduce(f, op, propertynames(x))` +# where `f` may call `getproperty` and thus have return type dependent +# on the particular symbol. +# `mapreduce` hasn't received any such specialization. +@inline diffeqmapreduce(f::F, op::OP, x::Tuple) where {F, OP} = reduce_tup(op, map(f, x)) +@inline function diffeqmapreduce(f::F, op::OP, x::NamedTuple) where {F, OP} + reduce_tup(op, map(f, x)) +end +# For other container types, we probably just want to call `mapreduce` +@inline diffeqmapreduce(f::F, op::OP, x) where {F, OP} = mapreduce(f, op, x, init = Any) + +getval(::Val{I}) where {I} = I +getval(::Type{Val{I}}) where {I} = I +getval(I::Int) = I + +const DUALCHECK_RECURSION_MAX = 10 + +function (dec::DualEltypeChecker)(::Val{Y}) where {Y} + isdefined(dec.x, Y) || return Any + getval(dec.counter) >= DUALCHECK_RECURSION_MAX && return Any + anyeltypedual(getfield(dec.x, Y), Val{getval(dec.counter)}) +end + +# Untyped dispatch: catch composite types, check all of their fields +""" + anyeltypedual(x) + +Searches through a type to see if any of its values are parameters. This is used to +then promote other values to match the dual type. For example, if a user passes a parameter + +which is a `Dual` and a `u0` which is a `Float64`, after the first time step, `f(u,p,t) = p*u` +will change `u0` from `Float64` to `Dual`. Thus the state variable always needs to be converted +to a dual number before the solve. Worse still, this needs to be done in the case of +`f(du,u,p,t) = du[1] = p*u[1]`, and thus running `f` and taking the return value is not a valid +way to calculate the required state type. + +But given the properties of automatic differentiation requiring that differentiation of parameters +implies differentiation of state, we assume any dual parameters implies differentiation of state +and then attempt to upconvert `u0` to match that dual-ness. Because this changes types, this needs +to be specified at compiled time and thus cannot have a Bool-based opt out, so in the future this +may be extended to use a preference system to opt-out with a `UPCONVERT_DUALS`. In the case where +upconversion is not done automatically, the user is required to upconvert all initial conditions +themselves, for an example of how this can be confusing to a user see +https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-forced-differential-equation/82937 +""" +@generated function anyeltypedual(x, ::Type{Val{counter}}) where {counter} + x = x.name === Core.Compiler.typename(Type) ? x.parameters[1] : x + if isdualtype(x) + :($x) + elseif fieldnames(x) === () + :(Any) + elseif counter < DUALCHECK_RECURSION_MAX + T = diffeqmapreduce(x -> anyeltypedual(x, Val{counter + 1}), promote_dual, + x.parameters) + if T === Any || isconcretetype(T) + :($T) + else + :(diffeqmapreduce(DualEltypeChecker($x, $counter + 1), promote_dual, + map(Val, fieldnames($(typeof(x)))))) + end + else + :(Any) + end +end + +const FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE = """ + Failed to automatically detect ForwardDiff compatability of + the parameter object. In order for ForwardDiff.jl automatic + differentiation to work on a solution object, the state of + the differential equation or nonlinear solve (`u0`) needs to + be converted to a Dual type which matches the values being + differentiated. For example, for a loss function loss(p) + where `p`` is a `Vector{Float64}`, this conversion is + equivalent to: + + ```julia + # Convert u0 to match the new Dual element type of `p` + _prob = remake(prob, u0 = eltype(p).(prob.u0)) + ``` + + In most cases, SciML tools are able to do this conversion + automatically. However, it seems you have provided a + parameter type for which this automatic conversion has failed. + + To fix this, you can do the conversion yourself. For example, + if you have a parameter vector being optimized `p` which is + then put into an odd struct, you can manually convert `u0` + to match `p`: + + ```julia + function loss(p) + _prob = remake(prob, u0 = eltype(p).(prob.u0), p = MyStruct(p)) + sol = solve(_prob, ...) + # do stuff on sol + end + ``` + + Or you can define a dispatch on `DiffEqBase.anyeltypedual` + which tells the system what fields to interpret as the + differentiable parts. For example, to support ODESolutions + as parameters we tell it the data is `sol.u` and `sol.t` via: + + ```julia + function DiffEqBase.anyeltypedual(sol::ODESolution, counter = 0) + DiffEqBase.anyeltypedual((sol.u, sol.t)) + end + ``` + + To opt a type out of the dual checking, define an overload + that returns Any. For example: + + ```julia + function DiffEqBase.anyeltypedual(::YourType, ::Type{Val{counter}}) where {counter} + Any + end + ``` + + If you have defined this on a common type which should + be more generally supported, please open a pull request + adding this dispatch. If you need help defining this dispatch, + feel free to open an issue. + """ + +struct ForwardDiffAutomaticDetectionFailure <: Exception end + +function Base.showerror(io::IO, e::ForwardDiffAutomaticDetectionFailure) + print(io, FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE) +end + +function anyeltypedual(::Type{Union{}}) + throw(ForwardDiffAutomaticDetectionFailure()) +end + +function anyeltypedual(::Type{<:AbstractTimeseriesSolution{T, N}}, + ::Type{Val{counter}} = Val{0}) where {T, N, counter} + anyeltypedual(T) +end + +function anyeltypedual( + ::Type{T}, + ::Type{Val{counter}} = Val{0}) where {counter} where {T <: + NonlinearProblem{ + uType, iip, pType}} where {uType, iip, pType} + return anyeltypedual((uType, pType), Val{counter}) +end + +function anyeltypedual( + ::Type{T}, + ::Type{Val{counter}} = Val{0}) where {counter} where {T <: + NonlinearLeastSquaresProblem{ + uType, iip, pType}} where {uType, iip, pType} + return anyeltypedual((uType, pType), Val{counter}) +end + +function anyeltypedual(x::SciMLBase.RecipesBase.AbstractPlot, + ::Type{Val{counter}} = Val{0}) where {counter} + Any +end +function anyeltypedual(x::Returns, ::Type{Val{counter}} = Val{0}) where {counter} + anyeltypedual(x.value, Val{counter}) +end + +Base.@assume_effects :foldable function __anyeltypedual(::Type{T}) where {T} + if T isa Union + promote_dual(anyeltypedual(T.a), anyeltypedual(T.b)) + elseif hasproperty(T, :parameters) + mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) + else + T + end +end +function anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T} + __anyeltypedual(T) +end + +function anyeltypedual(::Type{T}, + ::Type{Val{counter}} = Val{0}) where {counter} where {T <: + Union{AbstractArray, Set}} + anyeltypedual(eltype(T)) +end +Base.@pure function __anyeltypedual_ntuple(::Type{T}) where {T <: NTuple} + if isconcretetype(eltype(T)) + return eltype(T) + end + if isempty(T.parameters) + Any + else + mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) + end +end +function anyeltypedual( + ::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: NTuple} + __anyeltypedual_ntuple(T) +end + +# Any in this context just means not Dual +function anyeltypedual( + x::SciMLBase.NullParameters, ::Type{Val{counter}} = Val{0}) where {counter} + Any +end + +function anyeltypedual(sol::RecursiveArrayTools.AbstractDiffEqArray, counter = 0) + diffeqmapreduce(anyeltypedual, promote_dual, (sol.u, sol.t)) +end + +function anyeltypedual(prob::Union{ODEProblem, SDEProblem, RODEProblem, DDEProblem}, + ::Type{Val{counter}} = Val{0}) where {counter} + anyeltypedual((prob.u0, prob.p, prob.tspan)) +end + +function anyeltypedual( + prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem, OptimizationProblem}, + ::Type{Val{counter}} = Val{0}) where {counter} + anyeltypedual((prob.u0, prob.p)) +end + +function anyeltypedual(x::Number, ::Type{Val{counter}} = Val{0}) where {counter} + anyeltypedual(typeof(x)) +end +function anyeltypedual( + x::Union{String, Symbol}, ::Type{Val{counter}} = Val{0}) where {counter} + typeof(x) +end +function anyeltypedual(x::Union{AbstractArray{T}, Set{T}}, + ::Type{Val{counter}} = Val{0}) where {counter} where { + T <: + Union{Number, + Symbol, + String}} + anyeltypedual(T) +end +function anyeltypedual(x::Union{AbstractArray{T}, Set{T}}, + ::Type{Val{counter}} = Val{0}) where {counter} where { + T <: Union{ + AbstractArray{ + <:Number, + }, + Set{ + <:Number, + }}} + anyeltypedual(eltype(x)) +end +function anyeltypedual(x::Union{AbstractArray{T}, Set{T}}, + ::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}} + anyeltypedual(eltype(x)) +end + +# Try to avoid this dispatch because it can lead to type inference issues when !isconcrete(eltype(x)) +function anyeltypedual(x::AbstractArray, ::Type{Val{counter}} = Val{0}) where {counter} + if isconcretetype(eltype(x)) + anyeltypedual(eltype(x)) + elseif !isempty(x) && all(i -> isassigned(x, i), 1:length(x)) && + counter < DUALCHECK_RECURSION_MAX + _counter = Val{counter + 1} + mapreduce(y -> anyeltypedual(y, _counter), promote_dual, x) + else + # This fallback to Any is required since otherwise we cannot handle `undef` in all cases + # misses cases of + Any + end +end + +function anyeltypedual(x::Set, ::Type{Val{counter}} = Val{0}) where {counter} + if isconcretetype(eltype(x)) + anyeltypedual(eltype(x)) + else + # This fallback to Any is required since otherwise we cannot handle `undef` in all cases + Any + end +end + +function anyeltypedual(x::Tuple, ::Type{Val{counter}} = Val{0}) where {counter} + # Handle the empty tuple case separately for inference and to avoid mapreduce error + if x === () + Any + else + diffeqmapreduce(anyeltypedual, promote_dual, x) + end +end +function anyeltypedual(x::AbstractDict, ::Type{Val{counter}} = Val{0}) where {counter} + isempty(x) ? eltype(values(x)) : mapreduce(anyeltypedual, promote_dual, values(x)) +end +function anyeltypedual(x::NamedTuple, ::Type{Val{counter}} = Val{0}) where {counter} + anyeltypedual(values(x)) +end + +function anyeltypedual( + f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter} + Any +end + +anyeltypedual(::@Kwargs{}, ::Type{Val{counter}} = Val{0}) where {counter} = Any +anyeltypedual(::Type{@Kwargs{}}, ::Type{Val{counter}} = Val{0}) where {counter} = Any + +# Opt out since these are using for preallocation, not differentiation +function anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module}, + ::Type{Val{counter}} = Val{0}) where {counter} + Any +end +function anyeltypedual(x::Type{T}, + ::Type{Val{counter}} = Val{0}) where {counter} where {T <: + ForwardDiff.AbstractConfig} + Any +end + +function anyeltypedual(x::ForwardDiff.DiffResults.DiffResult, + ::Type{Val{counter}} = Val{0}) where {counter} + Any +end +function anyeltypedual(x::Type{T}, + ::Type{Val{counter}} = Val{0}) where {counter} where {T <: + ForwardDiff.DiffResults.DiffResult} + Any +end + +function anyeltypedual(::Type{T}, + ::Type{Val{counter}} = Val{0}) where {counter} where {T <: ForwardDiff.Dual} + T +end + +function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, tspan, prob, kwargs) + if (haskey(kwargs, :callback) && has_continuous_callback(kwargs[:callback])) || + (haskey(prob.kwargs, :callback) && has_continuous_callback(prob.kwargs[:callback])) + return _promote_tspan(eltype(u0).(tspan), kwargs) + else + return _promote_tspan(tspan, kwargs) + end +end + +function promote_tspan(u0::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, p, tspan, prob, + kwargs) + return _promote_tspan(real(eltype(u0)).(tspan), kwargs) +end + +function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, + tspan::Tuple{<:ForwardDiff.Dual, <:ForwardDiff.Dual}, prob, kwargs) + return _promote_tspan(tspan, kwargs) +end + +value(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V +value(x::ForwardDiff.Dual) = value(ForwardDiff.value(x)) + +unitfulvalue(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V +unitfulvalue(x::ForwardDiff.Dual) = unitfulvalue(ForwardDiff.value(x)) + +sse(x::ForwardDiff.Dual) = sse(ForwardDiff.value(x)) + sum(sse, ForwardDiff.partials(x)) + +function SciMLBase.totallength(x::ForwardDiff.Dual) + return SciMLBase.totallength(ForwardDiff.value(x)) + + sum(SciMLBase.totallength, ForwardDiff.partials(x)) +end + +end \ No newline at end of file diff --git a/ext/SciMLBaseMeasurementsExt.jl b/ext/SciMLBaseMeasurementsExt.jl new file mode 100644 index 000000000..1b643cf18 --- /dev/null +++ b/ext/SciMLBaseMeasurementsExt.jl @@ -0,0 +1,18 @@ +module SciMLBaseMeasurementsExt + +using Measurements +using SciMLBase: value + +function SciMLBase.promote_u0(u0::AbstractArray{<:Measurements.Measurement}, + p::AbstractArray{<:Measurements.Measurement}, t0) + u0 +end +SciMLBase.promote_u0(u0, p::AbstractArray{<:Measurements.Measurement}, t0) = eltype(p).(u0) + +value(x::Type{Measurements.Measurement{T}}) where {T} = T +value(x::Measurements.Measurement) = Measurements.value(x) + +unitfulvalue(x::Type{Measurements.Measurement{T}}) where {T} = T +unitfulvalue(x::Measurements.Measurement) = Measurements.value(x) + +end diff --git a/ext/SciMLBaseMonteCarloMeasurementsExt.jl b/ext/SciMLBaseMonteCarloMeasurementsExt.jl new file mode 100644 index 000000000..ea1a7c95c --- /dev/null +++ b/ext/SciMLBaseMonteCarloMeasurementsExt.jl @@ -0,0 +1,35 @@ +module SciMLBaseMonteCarloMeasurementsExt + +using SciMLBase +using SciMLBase: value +using MonteCarloMeasurements + +function SciMLBase.promote_u0( + u0::AbstractArray{ + <:MonteCarloMeasurements.AbstractParticles, + }, + p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, + t0) + u0 +end +function SciMLBase.promote_u0(u0, + p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, + t0) + eltype(p).(u0) +end + +function SciMLBase.promote_u0(::Nothing, + p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, + t0) + return nothing +end + +SciMLBase.value(x::Type{MonteCarloMeasurements.AbstractParticles{T, N}}) where {T, N} = T +SciMLBase.value(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles) +function SciMLBase.unitfulvalue(x::Type{MonteCarloMeasurements.AbstractParticles{ + T, N}}) where {T, N} + T +end +SciMLBase.unitfulvalue(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles) + +end \ No newline at end of file diff --git a/ext/SciMLBaseMooncakeExt.jl b/ext/SciMLBaseMooncakeExt.jl new file mode 100644 index 000000000..ddc779823 --- /dev/null +++ b/ext/SciMLBaseMooncakeExt.jl @@ -0,0 +1,25 @@ +module SciMLBaseMooncakeExt + +using SciMLBase, Mooncake +using SciMLBase: ADOriginator, ChainRulesOriginator, MooncakeOriginator +import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive, + @from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx, + NoPullback + +@zero_adjoint MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any} +@is_primitive MinimalCtx Tuple{ + typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake), SciMLBase.ChainRulesOriginator +} + +@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = SciMLBase.MooncakeOriginator() + +function rrule!!( + f::CoDual{typeof(SciMLBase.set_mooncakeoriginator_if_mooncake)}, + X::CoDual{SciMLBase.ChainRulesOriginator} +) + return zero_fcodual(SciMLBase.MooncakeOriginator()), NoPullback(f, X) +end + + + +end \ No newline at end of file diff --git a/ext/SciMLBaseReverseDiffExt.jl b/ext/SciMLBaseReverseDiffExt.jl new file mode 100644 index 000000000..6aad9e952 --- /dev/null +++ b/ext/SciMLBaseReverseDiffExt.jl @@ -0,0 +1,57 @@ +module SciMLBaseReverseDiffExt + +using SciMLBase +using ReverseDiff + +function SciMLBase.anyeltypedual(::Type{T}, + ::Type{Val{counter}} = Val{0}) where {counter} where { + V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}} + SciMLBase.anyeltypedual(V, Val{counter}) +end + +SciMLBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V +function SciMLBase.value(x::Type{ + ReverseDiff.TrackedArray{V, D, N, VA, DA}, +}) where {V, D, + N, VA, + DA} + Array{V, N} +end +SciMLBase.value(x::ReverseDiff.TrackedReal) = x.value +SciMLBase.value(x::ReverseDiff.TrackedArray) = x.value + +SciMLBase.unitfulvalue(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V +function SciMLBase.unitfulvalue(x::Type{ + ReverseDiff.TrackedArray{V, D, N, VA, DA}, +}) where {V, D, + N, VA, + DA} + Array{V, N} +end +SciMLBase.unitfulvalue(x::ReverseDiff.TrackedReal) = x.value +SciMLBase.unitfulvalue(x::ReverseDiff.TrackedArray) = x.value + +# Force TrackedArray from TrackedReal when reshaping W\b +SciMLBase._reshape(v::AbstractVector{<:ReverseDiff.TrackedReal}, siz) = reduce(vcat, v) + +SciMLBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0 +function SciMLBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, + p::ReverseDiff.TrackedArray, t0) + u0 +end +function SciMLBase.promote_u0(u0::ReverseDiff.TrackedArray, + p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) + u0 +end +function SciMLBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, + p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) + u0 +end +SciMLBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0) +function SciMLBase.promote_u0( + u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ReverseDiff.ForwardDiff.Dual} + ReverseDiff.track(T.(u0)) +end +SciMLBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0) + +end \ No newline at end of file diff --git a/ext/SciMLBaseTrackerExt.jl b/ext/SciMLBaseTrackerExt.jl new file mode 100644 index 000000000..3cc385f13 --- /dev/null +++ b/ext/SciMLBaseTrackerExt.jl @@ -0,0 +1,37 @@ +module SciMLBaseTrackerExt + +using SciMLBase +import Tracker + +SciMLBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T +SciMLBase.value(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array{T, N} +SciMLBase.value(x::Tracker.TrackedReal) = x.data +SciMLBase.value(x::Tracker.TrackedArray) = x.data + +SciMLBase.unitfulvalue(x::Type{Tracker.TrackedReal{T}}) where {T} = T +function SciMLBase.unitfulvalue(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} + Array{T, N} +end +SciMLBase.unitfulvalue(x::Tracker.TrackedReal) = x.data +SciMLBase.unitfulvalue(x::Tracker.TrackedArray) = x.data + +SciMLBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0 +function SciMLBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, + p::Tracker.TrackedArray, t0) + u0 +end +function SciMLBase.promote_u0(u0::Tracker.TrackedArray, + p::AbstractArray{<:Tracker.TrackedReal}, t0) + u0 +end +function SciMLBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, + p::AbstractArray{<:Tracker.TrackedReal}, t0) + u0 +end +SciMLBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0) +SciMLBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p).(u0) + +@inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x)) + + +end \ No newline at end of file diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 54beb0640..b0f5afec4 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -27,7 +27,7 @@ import Accessors: @set, @reset, @delete, @insert using Moshi.Data: @data using Moshi.Match: @match import Moshi.Derive: @derive -import StaticArraysCore +import StaticArraysCore: StaticArraysCore, SArray import Adapt: adapt_structure, adapt using Reexport @@ -36,7 +36,7 @@ using SciMLOperators: AbstractSciMLOperator, IdentityOperator, NullOperator, ScaledOperator, AddedOperator, ComposedOperator, - InvertedOperator, InvertibleOperator + InvertedOperator, InvertibleOperator, AbstractSciMLScalarOperator import SciMLOperators: DEFAULT_UPDATE_FUNC, update_coefficients, update_coefficients!, @@ -719,6 +719,7 @@ $(TYPEDEF) abstract type AbstractParameterizedFunction{iip} <: AbstractODEFunction{iip} end include("retcodes.jl") +include("errors.jl") include("symbolic_utils.jl") include("performance_warnings.jl") diff --git a/src/errors.jl b/src/errors.jl new file mode 100644 index 000000000..6de0c4965 --- /dev/null +++ b/src/errors.jl @@ -0,0 +1,472 @@ +const allowedkeywords = (:dense, + :saveat, + :save_idxs, + :tstops, + :tspan, + :d_discontinuities, + :save_everystep, + :save_on, + :save_start, + :save_end, + :initialize_save, + :adaptive, + :abstol, + :reltol, + :dt, + :dtmax, + :dtmin, + :force_dtmin, + :internalnorm, + :controller, + :gamma, + :beta1, + :beta2, + :qmax, + :qmin, + :qsteady_min, + :qsteady_max, + :qoldinit, + :failfactor, + :calck, + :alias_u0, + :maxiters, + :maxtime, + :callback, + :isoutofdomain, + :unstable_check, + :verbose, + :merge_callbacks, + :progress, + :progress_steps, + :progress_name, + :progress_message, + :progress_id, + :timeseries_errors, + :dense_errors, + :weak_timeseries_errors, + :weak_dense_errors, + :wrap, + :calculate_error, + :initializealg, + :alg, + :save_noise, + :delta, + :seed, + :alg_hints, + :kwargshandle, + :trajectories, + :batch_size, + :sensealg, + :advance_to_tstop, + :stop_at_next_tstop, + :u0, + :p, + # These two are from the default algorithm handling + :default_set, + :second_time, + # This is for DiffEqDevTools + :prob_choice, + # Jump problems + :alias_jump, + # This is for copying/deepcopying noise in StochasticDiffEq + :alias_noise, + # This is for SimpleNonlinearSolve handling for batched Nonlinear Solves + :batch, + # Shooting method in BVP needs to differentiate between these two categories + :nlsolve_kwargs, + :odesolve_kwargs, + # If Solvers which internally use linsolve + :linsolve_kwargs, + # Solvers internally using EnsembleProblem + :ensemblealg, + # Fine Grained Control of Tracing (Storing and Logging) during Solve + :show_trace, + :trace_level, + :store_trace, + # Termination condition for solvers + :termination_condition, + # For AbstractAliasSpecifier + :alias, + # Parameter estimation with BVP + :fit_parameters) + + +const KWARGWARN_MESSAGE = """ + Unrecognized keyword arguments found. + The only allowed keyword arguments to `solve` are: + $allowedkeywords + + See https://diffeq.sciml.ai/stable/basics/common_solver_opts/ for more details. + + Set kwargshandle=KeywordArgError for an error message. + Set kwargshandle=KeywordArgSilent to ignore this message. + """ + +const KWARGERROR_MESSAGE = """ + Unrecognized keyword arguments found. + The only allowed keyword arguments to `solve` are: + $allowedkeywords + + See https://diffeq.sciml.ai/stable/basics/common_solver_opts/ for more details. + """ + +struct CommonKwargError <: Exception + kwargs::Any +end + +function Base.showerror(io::IO, e::CommonKwargError) + println(io, KWARGERROR_MESSAGE) + notin = collect(map(x -> x ∉ allowedkeywords, keys(e.kwargs))) + unrecognized = collect(keys(e.kwargs))[notin] + print(io, "Unrecognized keyword arguments: ") + printstyled(io, unrecognized; bold = true, color = :red) + print(io, "\n\n") + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +@enum KeywordArgError KeywordArgWarn KeywordArgSilent + +const INCOMPATIBLE_U0_MESSAGE = """ + Initial condition incompatible with functional form. + Detected an in-place function with an initial condition of type Number or SArray. + This is incompatible because Numbers cannot be mutated, i.e. + `x = 2.0; y = 2.0; x .= y` will error. + + If using a immutable initial condition type, please use the out-of-place form. + I.e. define the function `du=f(u,p,t)` instead of attempting to "mutate" the immutable `du`. + + If your differential equation function was defined with multiple dispatches and one is + in-place, then the automatic detection will choose in-place. In this case, override the + choice in the problem constructor, i.e. `ODEProblem{false}(f,u0,tspan,p,kwargs...)`. + + For a longer discussion on mutability vs immutability and in-place vs out-of-place, see: + https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Example-Accelerating-a-Non-Stiff-Equation:-The-Lorenz-Equation + """ + +struct IncompatibleInitialConditionError <: Exception end + +function Base.showerror(io::IO, e::IncompatibleInitialConditionError) + print(io, INCOMPATIBLE_U0_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const NO_DEFAULT_ALGORITHM_MESSAGE = """ + Default algorithm choices require DifferentialEquations.jl. + Please specify an algorithm (e.g., `solve(prob, Tsit5())` or + `init(prob, Tsit5())` for an ODE) or import DifferentialEquations + directly. + + You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ + and its associated pages. + """ + +struct NoDefaultAlgorithmError <: Exception end + +function Base.showerror(io::IO, e::NoDefaultAlgorithmError) + print(io, NO_DEFAULT_ALGORITHM_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const NO_TSPAN_MESSAGE = """ + No tspan is set in the problem or chosen in the init/solve call + """ + +struct NoTspanError <: Exception end + +function Base.showerror(io::IO, e::NoTspanError) + print(io, NO_TSPAN_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const NAN_TSPAN_MESSAGE = """ + NaN tspan is set in the problem or chosen in the init/solve call. + Note that -Inf and Inf values are allowed in the timespan for solves + which are terminated via callbacks, however NaN values are not allowed + since the direction of time is undetermined. + """ + +struct NaNTspanError <: Exception end + +function Base.showerror(io::IO, e::NaNTspanError) + print(io, NAN_TSPAN_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const NON_SOLVER_MESSAGE = """ + The arguments to solve are incorrect. + The second argument must be a solver choice, `solve(prob,alg)` + where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`. + + Please double check the arguments being sent to the solver. + + You can find the list of available solvers at https://diffeq.sciml.ai/stable/solvers/ode_solve/ + and its associated pages. + """ + +struct NonSolverError <: Exception end + +function Base.showerror(io::IO, e::NonSolverError) + print(io, NON_SOLVER_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const NOISE_SIZE_MESSAGE = """ + Noise sizes are incompatible. The expected number of noise terms in the defined + `noise_rate_prototype` does not match the number of noise terms in the defined + `AbstractNoiseProcess`. Please ensure that + size(prob.noise_rate_prototype,2) == length(prob.noise.W[1]). + + Note: Noise process definitions require that users specify `u0`, and this value is + directly used in the definition. For example, if `noise = WienerProcess(0.0,0.0)`, + then the noise process is a scalar with `u0=0.0`. If `noise = WienerProcess(0.0,[0.0])`, + then the noise process is a vector with `u0=0.0`. If `noise_rate_prototype = zeros(2,4)`, + then the noise process must be a 4-dimensional process, for example + `noise = WienerProcess(0.0,zeros(4))`. This error is a sign that the user definition + of `noise_rate_prototype` and `noise` are not aligned in this manner and the definitions should + be double checked. + """ + +struct NoiseSizeIncompatabilityError <: Exception + prototypesize::Int + noisesize::Int +end + +function Base.showerror(io::IO, e::NoiseSizeIncompatabilityError) + println(io, NOISE_SIZE_MESSAGE) + println(io, "size(prob.noise_rate_prototype,2) = $(e.prototypesize)") + println(io, "length(prob.noise.W[1]) = $(e.noisesize)") + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const PROBSOLVER_PAIRING_MESSAGE = """ + Incompatible problem+solver pairing. + For example, this can occur if an ODE solver is passed with an SDEProblem. + Solvers are only capable of handling specific problem types. Please double + check that the chosen pairing is capable for handling the given problems. + """ + +struct ProblemSolverPairingError <: Exception + prob::Any + alg::Any +end + +function Base.showerror(io::IO, e::ProblemSolverPairingError) + println(io, PROBSOLVER_PAIRING_MESSAGE) + println(io, "Problem type: $(SciMLBase.__parameterless_type(typeof(e.prob)))") + println(io, "Solver type: $(SciMLBase.__parameterless_type(typeof(e.alg)))") + println(io, + "Problem types compatible with the chosen solver: $(compatible_problem_types(e.prob,e.alg))") + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +function compatible_problem_types(prob, alg) + if alg isa AbstractODEAlgorithm + ODEProblem + elseif alg isa AbstractSDEAlgorithm + (SDEProblem, SDDEProblem) + elseif alg isa AbstractDDEAlgorithm # StochasticDelayDiffEq.jl just uses the SDE alg + DDEProblem + elseif alg isa AbstractDAEAlgorithm + DAEProblem + elseif alg isa AbstractSteadyStateAlgorithm + SteadyStateProblem + end +end + +const DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE = """ + Incompatible solver + automatic differentiation pairing. + The chosen automatic differentiation algorithm requires the ability + for compiler transforms on the code which is only possible on pure-Julia + solvers such as those from OrdinaryDiffEq.jl. Direct differentiation methods + which require this ability include: + + - Direct use of ForwardDiff.jl on the solver + - `ForwardDiffSensitivity`, `ReverseDiffAdjoint`, `TrackerAdjoint`, and `ZygoteAdjoint` + sensealg choices for adjoint differentiation. + + Either switch the choice of solver to a pure Julia method, or change the automatic + differentiation method to one that does not require such transformations. + + For more details on automatic differentiation, adjoint, and sensitivity analysis + of differential equations, see the documentation page: + + https://diffeq.sciml.ai/stable/analysis/sensitivity/ + """ + +struct DirectAutodiffError <: Exception end + +function Base.showerror(io::IO, e::DirectAutodiffError) + println(io, DIRECT_AUTODIFF_INCOMPATABILITY_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const NONNUMBER_ELTYPE_MESSAGE = """ + Non-Number element type inside of an `Array` detected. + Arrays with non-number element types, such as + `Array{Array{Float64}}`, are not supported by the + solvers. + + If you are trying to use an array of arrays structure, + look at the tools in RecursiveArrayTools.jl. For example: + + If this was a mistake, promote the element types to be + all the same. If this was intentional, for example, + using Unitful.jl with different unit values, then use + an array type which has fast broadcast support for + heterogeneous values such as the ArrayPartition + from RecursiveArrayTools.jl. For example: + + ```julia + using RecursiveArrayTools + u0 = ArrayPartition([1.0,2.0],[3.0,4.0]) + u0 = VectorOfArray([1.0,2.0],[3.0,4.0]) + ``` + + are both initial conditions which would be compatible with + the solvers. Or use ComponentArrays.jl for more complex + nested structures. + + Element type: + """ + +struct NonNumberEltypeError <: Exception + eltype::Any +end + +function Base.showerror(io::IO, e::NonNumberEltypeError) + print(io, NONNUMBER_ELTYPE_MESSAGE) + print(io, e.eltype) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const GENERIC_NUMBER_TYPE_ERROR_MESSAGE = """ + Non-standard number type (i.e. not Float32, Float64, + ComplexF32, or ComplexF64) detected as the element type + for the initial condition or time span. These generic + number types are only compatible with the pure Julia + solvers which support generic programming, such as + OrdinaryDiffEq.jl. The chosen solver does not support + this functionality. Please double check that the initial + condition and time span types are correct, and check that + the chosen solver was correct. + """ + +struct GenericNumberTypeError <: Exception + alg::Any + uType::Any + tType::Any +end + +function Base.showerror(io::IO, e::GenericNumberTypeError) + println(io, GENERIC_NUMBER_TYPE_ERROR_MESSAGE) + println(io, "Solver: $(e.alg)") + println(io, "u0 type: $(e.uType)") + print(io, "Timespan type: $(e.tType)") + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const COMPLEX_SUPPORT_ERROR_MESSAGE = """ + Complex number type (i.e. ComplexF32, or ComplexF64) + detected as the element type for the initial condition + with an algorithm that does not support complex numbers. + Please check that the initial condition type is correct. + If complex number support is needed, try different solvers + such as those from OrdinaryDiffEq.jl. + """ + +struct ComplexSupportError <: Exception + alg::Any +end + +function Base.showerror(io::IO, e::ComplexSupportError) + println(io, COMPLEX_SUPPORT_ERROR_MESSAGE) + println(io, "Solver: $(e.alg)") + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const COMPLEX_TSPAN_ERROR_MESSAGE = """ + Complex number type (i.e. ComplexF32, or ComplexF64) + detected as the element type for the independent variable + (i.e. time span). Please check that the tspan type is correct. + No solvers support complex time spans. If this is required, + please open an issue. + """ + +struct ComplexTspanError <: Exception end + +function Base.showerror(io::IO, e::ComplexTspanError) + println(io, COMPLEX_TSPAN_ERROR_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const TUPLE_STATE_ERROR_MESSAGE = """ + Tuple type used as a state. Since a tuple does not have vector + properties, it will not work as a state type in equation solvers. + Instead, change your equation from using tuple constructors `()` + to static array constructors `SA[]`. For example, change: + + ```julia + function ftup((a,b),p,t) + return b,-a + end + u0 = (1.0,2.0) + tspan = (0.0,1.0) + ODEProblem(ftup,u0,tspan) + ``` + + to: + + ```julia + using StaticArrays + function fsa(u,p,t) + SA[u[2],u[1]] + end + u0 = SA[1.0,2.0] + tspan = (0.0,1.0) + ODEProblem(ftup,u0,tspan) + ``` + + This will be safer and fast for small ODEs. For more information, see: + https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Further-Optimizations-of-Small-Non-Stiff-ODEs-with-StaticArrays + """ + +struct TupleStateError <: Exception end + +function Base.showerror(io::IO, e::TupleStateError) + println(io, TUPLE_STATE_ERROR_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const MASS_MATRIX_ERROR_MESSAGE = """ + Mass matrix size is incompatible with initial condition + sizing. The mass matrix must represent the `vec` + form of the initial condition `u0`, i.e. + `size(mm,1) == size(mm,2) == length(u)` + """ + +struct IncompatibleMassMatrixError <: Exception + sz::Int + len::Int +end + +function Base.showerror(io::IO, e::IncompatibleMassMatrixError) + println(io, MASS_MATRIX_ERROR_MESSAGE) + print(io, "size(prob.f.mass_matrix,1): ") + println(io, e.sz) + print(io, "length(u0): ") + println(e.len) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +const LATE_BINDING_TSTOPS_ERROR_MESSAGE = """ + This solver does not support providing `tstops` as a function. + Consider using a different solver or providing `tstops` as an array + of times. + """ + +struct LateBindingTstopsNotSupportedError <: Exception end + +function Base.showerror(io::IO, e::LateBindingTstopsNotSupportedError) + println(io, LATE_BINDING_TSTOPS_ERROR_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end \ No newline at end of file diff --git a/src/solve.jl b/src/solve.jl index c6e957936..65c7bcc09 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,18 +1,8 @@ -# Skip the DiffEqBase handling - -struct IncompatibleOptimizerError <: Exception - err::String -end - -function Base.showerror(io::IO, e::IncompatibleOptimizerError) - print(io, e.err) -end - const NONCONCRETE_ELTYPE_MESSAGE = """ Non-concrete element type inside of an `Array` detected. Arrays with non-concrete element types, such as `Array{Union{Float32,Float64}}`, are not supported by the - optimizers. Anyways, this is bad for + differential equation solvers. Anyways, this is bad for performance so you don't want to be doing this! If this was a mistake, promote the element types to be @@ -26,7 +16,7 @@ const NONCONCRETE_ELTYPE_MESSAGE = """ using RecursiveArrayTools x = ArrayPartition([1.0,2.0],[1f0,2f0]) y = ArrayPartition([3.0,4.0],[3f0,4f0]) - x .+ y # fast, stable, and usable as u0 in some optimizers + x .+ y # fast, stable, and usable as u0 into DiffEq! ``` Element type: @@ -39,6 +29,17 @@ end function Base.showerror(io::IO, e::NonConcreteEltypeError) print(io, NONCONCRETE_ELTYPE_MESSAGE) print(io, e.eltype) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +# Skip the DiffEqBase handling + +struct IncompatibleOptimizerError <: Exception + err::String +end + +function Base.showerror(io::IO, e::IncompatibleOptimizerError) + print(io, e.err) end """ @@ -246,3 +247,272 @@ end function __solve(prob::OptimizationProblem, alg, args...; kwargs...) throw(OptimizerMissingError(alg)) end + + +# Functions used in solve dispatches + +eltypedual(x) = false +promote_u0(::Nothing, p, t0) = nothing +isdualtype(::Type{T}) where {T} = false + +has_kwargs(_prob::AbstractSciMLProblem) = has_kwargs(typeof(_prob)) +Base.@pure __has_kwargs(::Type{T}) where {T} = :kwargs ∈ fieldnames(T) +has_kwargs(::Type{T}) where {T} = __has_kwargs(T) + +@inline function extract_alg(solve_args, solve_kwargs, prob_kwargs) + if isempty(solve_args) || isnothing(first(solve_args)) + if haskey(solve_kwargs, :alg) + solve_kwargs[:alg] + elseif haskey(prob_kwargs, :alg) + prob_kwargs[:alg] + else + nothing + end + elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm && + !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) + first(solve_args) + else + nothing + end +end + +handle_distribution_u0(_u0) = _u0 + +eval_u0(u0::Function) = true +eval_u0(u0) = false + +function get_concrete_p(prob, kwargs) + if haskey(kwargs, :p) + p = kwargs[:p] + else + p = prob.p + end +end + +function get_concrete_u0(prob::BVProblem, isadapt, t0, kwargs) + if haskey(kwargs, :u0) + u0 = kwargs[:u0] + else + u0 = prob.u0 + end + + isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) + + _u0 = handle_distribution_u0(u0) + + if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) + throw(IncompatibleInitialConditionError()) + end + + if _u0 isa Tuple + throw(TupleStateError()) + end + + return _u0 +end + +function checkkwargs(kwargshandle; kwargs...) + if any(x -> x ∉ allowedkeywords, keys(kwargs)) + if kwargshandle == KeywordArgError + throw(CommonKwargError(kwargs)) + elseif kwargshandle == KeywordArgWarn + @warn KWARGWARN_MESSAGE + unrecognized = setdiff(keys(kwargs), allowedkeywords) + print("Unrecognized keyword arguments: ") + printstyled(unrecognized; bold = true, color = :red) + print("\n\n") + else + @assert kwargshandle == KeywordArgSilent + end + end +end + +""" + $(TYPEDSIGNATURES) + +Given the index provider `indp` used to construct the problem `prob` being solved, return +an updated `prob` to be used for solving. All implementations should accept arbitrary +keyword arguments. + +Should be called before the problem is solved, after performing type-promotion on the +problem. If the returned problem is not `===` the provided `prob`, it is assumed to +contain the `u0` and `p` passed as keyword arguments. + +# Keyword Arguments + +- `u0`, `p`: Override values for `state_values(prob)` and `parameter_values(prob)` which + should be used instead of the ones in `prob`. +""" +function get_updated_symbolic_problem(indp, prob; kw...) + return prob +end + +function isconcreteu0(prob, t0, kwargs) + !eval_u0(prob.u0) && prob.u0 !== nothing && !isdistribution(prob.u0) +end + +function isconcretedu0(prob, t0, kwargs) + !eval_u0(prob.u0) && prob.du0 !== nothing && !isdistribution(prob.du0) +end + +function get_concrete_u0(prob, isadapt, t0, kwargs) + if eval_u0(prob.u0) + u0 = prob.u0(prob.p, t0) + elseif haskey(kwargs, :u0) + u0 = kwargs[:u0] + else + u0 = prob.u0 + end + + isadapt && eltype(u0) <: Integer && (u0 = float.(u0)) + + _u0 = handle_distribution_u0(u0) + + if isinplace(prob) && (_u0 isa Number || _u0 isa SArray) + throw(IncompatibleInitialConditionError()) + end + + nu0 = length(something(_u0, ())) + if isdefined(prob.f, :mass_matrix) && prob.f.mass_matrix !== nothing && + prob.f.mass_matrix isa AbstractArray && + size(prob.f.mass_matrix, 1) !== nu0 + throw(IncompatibleMassMatrixError(size(prob.f.mass_matrix, 1), nu0)) + end + + if _u0 isa Tuple + throw(TupleStateError()) + end + + _u0 +end + +function get_concrete_du0(prob, isadapt, t0, kwargs) + if eval_u0(prob.du0) + du0 = prob.du0(prob.p, t0) + elseif haskey(kwargs, :du0) + du0 = kwargs[:du0] + else + du0 = prob.du0 + end + + isadapt && eltype(du0) <: Integer && (du0 = float.(du0)) + + _du0 = handle_distribution_u0(du0) + + if isinplace(prob) && (_du0 isa Number || _du0 isa SArray) + throw(IncompatibleInitialConditionError()) + end + + _du0 +end + +function promote_u0(u0, p, t0) + if SciMLStructures.isscimlstructure(p) + _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] + if !isequal(_p, p) + return promote_u0(u0, _p, t0) + end + end + Tu = eltype(u0) + if isdualtype(Tu) + return u0 + end + Tp = anyeltypedual(p, Val{0}) + if Tp == Any + Tp = Tu + end + Tt = anyeltypedual(t0, Val{0}) + if Tt == Any + Tt = Tu + end + Tcommon = promote_type(Tu, Tp, Tt) + return if isdualtype(Tcommon) + Tcommon.(u0) + else + u0 + end +end + +function promote_u0(u0::AbstractArray{<:Complex}, p, t0) + if SciMLStructures.isscimlstructure(p) + _p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1] + if !isequal(_p, p) + return promote_u0(u0, _p, t0) + end + end + Tu = real(eltype(u0)) + if isdualtype(Tu) + return u0 + end + Tp = anyeltypedual(p, Val{0}) + if Tp == Any + Tp = Tu + end + Tt = anyeltypedual(t0, Val{0}) + if Tt == Any + Tt = Tu + end + Tcommon = promote_type(eltype(u0), Tp, Tt) + return if isdualtype(real(Tcommon)) + Tcommon.(u0) + else + u0 + end +end + +anyeltypedual(x) = anyeltypedual(x, Val{0}) +anyeltypedual(x, counter) = Any + +value(x) = x +unitfulvalue(x) = x +isdistribution(u0) = false +sse(x::Number) = abs2(x) + +struct DualEltypeChecker{T, T2} + x::T + counter::T2 +end + +@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...) +@inline function __sum( + f::F, a::StaticArraysCore.StaticArray...; init, kwargs...) where {F} + return mapreduce(f, +, a...; init, kwargs...) +end + +totallength(x::Number) = 1 +totallength(x::AbstractArray) = __sum(totallength, x; init = 0) + +_reshape(v, siz) = reshape(v, siz) +_reshape(v::Number, siz) = v +_reshape(v::AbstractSciMLScalarOperator, siz) = v + +set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = x + +# Copied from Static.jl https://github.com/SciML/Static.jl/blob/b50279cc9b33741fd60f382c789fbaef8622d964/src/Static.jl#L743 +@generated function reduce_tup(f::F, inds::Tuple{Vararg{Any, N}}) where {F, N} + q = Expr(:block, Expr(:meta, :inline, :propagate_inbounds)) + if N == 1 + push!(q.args, :(inds[1])) + return q + end + syms = Vector{Symbol}(undef, N) + i = 0 + for n in 1:N + syms[n] = iₙ = Symbol(:i_, (i += 1)) + push!(q.args, Expr(:(=), iₙ, Expr(:ref, :inds, n))) + end + W = 1 << (8sizeof(N) - 2 - leading_zeros(N)) + while W > 0 + _N = length(syms) + for _ in (2W):W:_N + for w in 1:W + new_sym = Symbol(:i_, (i += 1)) + push!(q.args, Expr(:(=), new_sym, Expr(:call, :f, syms[w], syms[w + W]))) + syms[w] = new_sym + end + deleteat!(syms, (1 + W):(2W)) + end + W >>>= 1 + end + q +end \ No newline at end of file