diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 6eb3e1c7a..37d82885c 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -347,7 +347,10 @@ function anyeltypedual(x::NamedTuple, ::Type{Val{counter}} = Val{0}) where {coun anyeltypedual(values(x)) end -DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter} = Any +function DiffEqBase.anyeltypedual( + f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter} + Any +end @inline promote_u0(::Nothing, p, t0) = nothing diff --git a/src/integrator_accessors.jl b/src/integrator_accessors.jl index 3a7550918..b89dd3d1b 100644 --- a/src/integrator_accessors.jl +++ b/src/integrator_accessors.jl @@ -1,9 +1,12 @@ # the following are setup per how integrators are implemented in OrdinaryDiffEq and # StochasticDiffEq and provide dispatch points that JumpProcesses and others can use. -get_tstops(integ::DEIntegrator) = +function get_tstops(integ::DEIntegrator) error("get_tstops not implemented for integrators of type $(nameof(typeof(integ)))") -get_tstops_array(integ::DEIntegrator) = +end +function get_tstops_array(integ::DEIntegrator) error("get_tstops_array not implemented for integrators of type $(nameof(typeof(integ)))") -get_tstops_max(integ::DEIntegrator) = +end +function get_tstops_max(integ::DEIntegrator) error("get_tstops_max not implemented for integrators of type $(nameof(typeof(integ)))") +end diff --git a/src/solve.jl b/src/solve.jl index 59996ea6e..c16a5c6cb 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1172,7 +1172,8 @@ function get_concrete_problem(prob, isadapt; kwargs...) tspan_promote = promote_tspan(u0_promote, p, tspan, prob, kwargs) f_promote = promote_f(prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p, tspan_promote[1]) - if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(prob.u0) && + if isconcreteu0(prob, tspan[1], kwargs) && prob.u0 === u0 && + typeof(u0_promote) === typeof(prob.u0) && prob.tspan == tspan && typeof(prob.tspan) === typeof(tspan_promote) && p === prob.p && f_promote === prob.f return prob @@ -1388,7 +1389,8 @@ function __solve( kwargs...) if second_time throw(NoDefaultAlgorithmError()) - elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm}) + elseif length(args) > 0 && !(first(args) isa + Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm}) throw(NonSolverError()) else __solve(prob, nothing, args...; default_set = false, second_time = true, kwargs...) @@ -1399,7 +1401,8 @@ function __init(prob::AbstractDEProblem, args...; default_set = false, second_ti kwargs...) if second_time throw(NoDefaultAlgorithmError()) - elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm}) + elseif length(args) > 0 && !(first(args) isa + Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm}) throw(NonSolverError()) else __init(prob, nothing, args...; default_set = false, second_time = true, kwargs...) diff --git a/test/downstream/prob_kwargs.jl b/test/downstream/prob_kwargs.jl index f4d7e2c1f..017466e0c 100644 --- a/test/downstream/prob_kwargs.jl +++ b/test/downstream/prob_kwargs.jl @@ -10,3 +10,7 @@ prob = ODEProblem(lorenz, u0, tspan, alg = Tsit5()) @test_nowarn sol = solve(prob, reltol = 1e-6) sol = solve(prob, reltol = 1e-6) @test sol.alg isa Tsit5 + +new_u0 = rand(3) +sol = solve(prob, u0 = new_u0) +@test sol.prob.u0 === new_u0 diff --git a/test/downstream/tables.jl b/test/downstream/tables.jl index 7279eb621..84f2ebec4 100644 --- a/test/downstream/tables.jl +++ b/test/downstream/tables.jl @@ -5,7 +5,11 @@ sol1 = solve(prob, Euler(); dt = 1 // 2^(4)); df = DataFrame(sol1) @test names(df) == ["timestamp", "value1", "value2", "value3", "value4"] -prob = ODEProblem(ODEFunction(f_2dlinear, sys = SymbolicIndexingInterface.SymbolCache([:a, :b, :c, :d], [], :t)), rand(2, 2), (0.0, 1.0)); +prob = ODEProblem( + ODEFunction( + f_2dlinear, sys = SymbolicIndexingInterface.SymbolCache([:a, :b, :c, :d], [], :t)), + rand(2, 2), + (0.0, 1.0)); sol2 = solve(prob, Euler(); dt = 1 // 2^(4)); df = DataFrame(sol2) @test names(df) == ["timestamp", "a", "b", "c", "d"]