diff --git a/src/ensemble/basic_ensemble_solve.jl b/src/ensemble/basic_ensemble_solve.jl index c71d56759..f9a1f7a78 100644 --- a/src/ensemble/basic_ensemble_solve.jl +++ b/src/ensemble/basic_ensemble_solve.jl @@ -352,3 +352,16 @@ function solve_batch(prob, alg, ::EnsembleSplitThreads, II, pmap_batch_size; kwa end reduce(vcat, batch_data) end + +function solve(prob::EnsembleProblem, args...; kwargs...) + alg = extract_alg(args, kwargs, kwargs) + if length(args) > 1 + __solve(prob, alg, Base.tail(args)...; kwargs...) + else + __solve(prob, alg; kwargs...) + end +end + +function solve(prob::SciMLBase.WeightedEnsembleProblem, args...; kwargs...) + WeightedEnsembleSolution(solve(prob.ensembleprob), prob.weights) +end \ No newline at end of file diff --git a/src/solve.jl b/src/solve.jl index 65c7bcc09..c3a9cd4dc 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -515,4 +515,41 @@ set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = x W >>>= 1 end q +end + +#### +# Catch undefined AD overload cases + +const ADJOINT_NOT_FOUND_MESSAGE = """ + Compatibility with reverse-mode automatic differentiation requires SciMLSensitivity.jl. + Please install SciMLSensitivity.jl and do `using SciMLSensitivity`/`import SciMLSensitivity` + for this functionality. For more details, see https://sensitivity.sciml.ai/dev/. + """ + +struct AdjointNotFoundError <: Exception end + +function Base.showerror(io::IO, e::AdjointNotFoundError) + print(io, ADJOINT_NOT_FOUND_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +function _concrete_solve_adjoint(args...; kwargs...) + throw(AdjointNotFoundError()) +end + +const FORWARD_SENSITIVITY_NOT_FOUND_MESSAGE = """ + Compatibility with forward-mode automatic differentiation requires SciMLSensitivity.jl. + Please install SciMLSensitivity.jl and do `using SciMLSensitivity`/`import SciMLSensitivity` + for this functionality. For more details, see https://sensitivity.sciml.ai/dev/. + """ + +struct ForwardSensitivityNotFoundError <: Exception end + +function Base.showerror(io::IO, e::ForwardSensitivityNotFoundError) + print(io, FORWARD_SENSITIVITY_NOT_FOUND_MESSAGE) + println(io, TruncatedStacktraces.VERBOSE_MSG) +end + +function _concrete_solve_forward(args...; kwargs...) + throw(ForwardSensitivityNotFoundError()) end \ No newline at end of file