Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,10 @@ function anyeltypedual(x::NamedTuple, ::Type{Val{counter}} = Val{0}) where {coun
anyeltypedual(values(x))
end

DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter} = Any
function DiffEqBase.anyeltypedual(
f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter}
Any
end

@inline promote_u0(::Nothing, p, t0) = nothing

Expand Down
9 changes: 6 additions & 3 deletions src/integrator_accessors.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# the following are setup per how integrators are implemented in OrdinaryDiffEq and
# StochasticDiffEq and provide dispatch points that JumpProcesses and others can use.

get_tstops(integ::DEIntegrator) =
function get_tstops(integ::DEIntegrator)
error("get_tstops not implemented for integrators of type $(nameof(typeof(integ)))")
get_tstops_array(integ::DEIntegrator) =
end
function get_tstops_array(integ::DEIntegrator)
error("get_tstops_array not implemented for integrators of type $(nameof(typeof(integ)))")
get_tstops_max(integ::DEIntegrator) =
end
function get_tstops_max(integ::DEIntegrator)
error("get_tstops_max not implemented for integrators of type $(nameof(typeof(integ)))")
end
9 changes: 6 additions & 3 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,8 @@ function get_concrete_problem(prob, isadapt; kwargs...)
tspan_promote = promote_tspan(u0_promote, p, tspan, prob, kwargs)
f_promote = promote_f(prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p,
tspan_promote[1])
if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(prob.u0) &&
if isconcreteu0(prob, tspan[1], kwargs) && prob.u0 === u0 &&
typeof(u0_promote) === typeof(prob.u0) &&
prob.tspan == tspan && typeof(prob.tspan) === typeof(tspan_promote) &&
p === prob.p && f_promote === prob.f
return prob
Expand Down Expand Up @@ -1388,7 +1389,8 @@ function __solve(
kwargs...)
if second_time
throw(NoDefaultAlgorithmError())
elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm})
elseif length(args) > 0 && !(first(args) isa
Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm})
throw(NonSolverError())
else
__solve(prob, nothing, args...; default_set = false, second_time = true, kwargs...)
Expand All @@ -1399,7 +1401,8 @@ function __init(prob::AbstractDEProblem, args...; default_set = false, second_ti
kwargs...)
if second_time
throw(NoDefaultAlgorithmError())
elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm})
elseif length(args) > 0 && !(first(args) isa
Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm})
throw(NonSolverError())
else
__init(prob, nothing, args...; default_set = false, second_time = true, kwargs...)
Expand Down
4 changes: 4 additions & 0 deletions test/downstream/prob_kwargs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ prob = ODEProblem(lorenz, u0, tspan, alg = Tsit5())
@test_nowarn sol = solve(prob, reltol = 1e-6)
sol = solve(prob, reltol = 1e-6)
@test sol.alg isa Tsit5

new_u0 = rand(3)
sol = solve(prob, u0 = new_u0)
@test sol.prob.u0 === new_u0
6 changes: 5 additions & 1 deletion test/downstream/tables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ sol1 = solve(prob, Euler(); dt = 1 // 2^(4));
df = DataFrame(sol1)
@test names(df) == ["timestamp", "value1", "value2", "value3", "value4"]

prob = ODEProblem(ODEFunction(f_2dlinear, sys = SymbolicIndexingInterface.SymbolCache([:a, :b, :c, :d], [], :t)), rand(2, 2), (0.0, 1.0));
prob = ODEProblem(
ODEFunction(
f_2dlinear, sys = SymbolicIndexingInterface.SymbolCache([:a, :b, :c, :d], [], :t)),
rand(2, 2),
(0.0, 1.0));
sol2 = solve(prob, Euler(); dt = 1 // 2^(4));
df = DataFrame(sol2)
@test names(df) == ["timestamp", "a", "b", "c", "d"]
Loading