Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,16 @@ function solve_batch(prob, alg, ::EnsembleSplitThreads, II, pmap_batch_size; kwa
end
reduce(vcat, batch_data)
end

function solve(prob::EnsembleProblem, args...; kwargs...)
alg = extract_alg(args, kwargs, kwargs)
if length(args) > 1
__solve(prob, alg, Base.tail(args)...; kwargs...)
else
__solve(prob, alg; kwargs...)
end
end

function solve(prob::SciMLBase.WeightedEnsembleProblem, args...; kwargs...)
WeightedEnsembleSolution(solve(prob.ensembleprob), prob.weights)
end
37 changes: 37 additions & 0 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,4 +515,41 @@ set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = x
W >>>= 1
end
q
end

####
# Catch undefined AD overload cases

const ADJOINT_NOT_FOUND_MESSAGE = """
Compatibility with reverse-mode automatic differentiation requires SciMLSensitivity.jl.
Please install SciMLSensitivity.jl and do `using SciMLSensitivity`/`import SciMLSensitivity`
for this functionality. For more details, see https://sensitivity.sciml.ai/dev/.
"""

struct AdjointNotFoundError <: Exception end

function Base.showerror(io::IO, e::AdjointNotFoundError)
print(io, ADJOINT_NOT_FOUND_MESSAGE)
println(io, TruncatedStacktraces.VERBOSE_MSG)
end

function _concrete_solve_adjoint(args...; kwargs...)
throw(AdjointNotFoundError())
end

const FORWARD_SENSITIVITY_NOT_FOUND_MESSAGE = """
Compatibility with forward-mode automatic differentiation requires SciMLSensitivity.jl.
Please install SciMLSensitivity.jl and do `using SciMLSensitivity`/`import SciMLSensitivity`
for this functionality. For more details, see https://sensitivity.sciml.ai/dev/.
"""

struct ForwardSensitivityNotFoundError <: Exception end

function Base.showerror(io::IO, e::ForwardSensitivityNotFoundError)
print(io, FORWARD_SENSITIVITY_NOT_FOUND_MESSAGE)
println(io, TruncatedStacktraces.VERBOSE_MSG)
end

function _concrete_solve_forward(args...; kwargs...)
throw(ForwardSensitivityNotFoundError())
end
Loading