Skip to content

Commit 62f0556

Browse files
Merge pull request #2641 from oscardssmith/os/fix-interpolation-type-stability
fix interpolation type stability
2 parents 505c73d + 1472a27 commit 62f0556

File tree

2 files changed

+17
-17
lines changed

2 files changed

+17
-17
lines changed

lib/OrdinaryDiffEqCore/src/dense/generic_dense.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,8 @@ ode_interpolation(tvals,ts,timeseries,ks)
553553
Get the value at tvals where the solution is known at the
554554
times ts (sorted), with values timeseries and derivatives ks
555555
"""
556-
function ode_interpolation(tvals, id::I, idxs, deriv::D, p,
557-
continuity::Symbol = :left) where {I, D}
556+
function ode_interpolation(tvals, id::I, idxs, ::Type{deriv}, p,
557+
continuity::Symbol = :left) where {I, deriv}
558558
@unpack ts, timeseries, ks, f, cache, differential_vars = id
559559
@inbounds tdir = sign(ts[end] - ts[1])
560560
idx = sortperm(tvals, rev = tdir < 0)
@@ -591,8 +591,8 @@ ode_interpolation(tvals,ts,timeseries,ks)
591591
Get the value at tvals where the solution is known at the
592592
times ts (sorted), with values timeseries and derivatives ks
593593
"""
594-
function ode_interpolation!(vals, tvals, id::I, idxs, deriv::D, p,
595-
continuity::Symbol = :left) where {I, D}
594+
function ode_interpolation!(vals, tvals, id::I, idxs, ::Type{deriv}, p,
595+
continuity::Symbol = :left) where {I, deriv}
596596
@unpack ts, timeseries, ks, f, cache, differential_vars = id
597597
@inbounds tdir = sign(ts[end] - ts[1])
598598
idx = sortperm(tvals, rev = tdir < 0)
@@ -756,8 +756,8 @@ ode_interpolation(tval::Number,ts,timeseries,ks)
756756
Get the value at tval where the solution is known at the
757757
times ts (sorted), with values timeseries and derivatives ks
758758
"""
759-
function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p,
760-
continuity::Symbol = :left) where {I, D}
759+
function ode_interpolation(tval::Number, id::I, idxs, ::Type{deriv}, p,
760+
continuity::Symbol = :left) where {I, deriv}
761761
@unpack ts, timeseries, ks, f, cache, differential_vars = id
762762
@inbounds tdir = sign(ts[end] - ts[1])
763763

@@ -840,8 +840,8 @@ ode_interpolation!(out,tval::Number,ts,timeseries,ks)
840840
Get the value at tval where the solution is known at the
841841
times ts (sorted), with values timeseries and derivatives ks
842842
"""
843-
function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p,
844-
continuity::Symbol = :left) where {I, D}
843+
function ode_interpolation!(out, tval::Number, id::I, idxs, ::Type{deriv}, p,
844+
continuity::Symbol = :left) where {I, deriv}
845845
@unpack ts, timeseries, ks, f, cache, differential_vars = id
846846
@inbounds tdir = sign(ts[end] - ts[1])
847847

test/interface/inplace_interpolation.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,23 @@ out_VMF = vecarrzero(ntt, size(prob_ode_2Dlinear.u0)) # Vector{Matrix{Float64}
1818
sol_ODE = solve(prob_ode_linear, alg; kwargs...)
1919
sol_ODE_2D = solve(prob_ode_2Dlinear, alg; kwargs...)
2020

21-
sol_ODE_interp = sol_ODE(tt)
22-
sol_ODE_2D_interp = sol_ODE_2D(tt)
21+
sol_ODE_interp = @inferred sol_ODE(tt)
22+
sol_ODE_2D_interp = @inferred sol_ODE_2D(tt)
2323

2424
@testset "1D" begin
2525
@test_throws MethodError sol_ODE(out_VF, tt; idxs = 1:1)
26-
@test sol_ODE(out_VF, tt) isa Vector{Float64}
27-
@test sol_ODE(out_VVF_1, tt) isa Vector{Vector{Float64}}
26+
@inferred Vector{Float64} sol_ODE(out_VF, tt)
27+
@inferred Vector{Vector{Float64}} sol_ODE(out_VVF_1, tt)
2828
@test sol_ODE_interp.u out_VF
2929
end
3030

3131
@testset "2D" begin
3232
@test_throws MethodError sol_ODE_2D(out_VF, tt; idxs = 3:3)
33-
@test sol_ODE_2D(out_VF, tt; idxs = 3) isa Vector{Float64}
34-
@test sol_ODE_2D(out_VVF_1, tt; idxs = 3) isa Vector{Vector{Float64}}
35-
@test sol_ODE_2D(out_VVF_1, tt; idxs = 3:3) isa Vector{Vector{Float64}}
36-
@test sol_ODE_2D(out_VVF_2, tt; idxs = 2:3) isa Vector{Vector{Float64}}
37-
@test sol_ODE_2D(out_VMF, tt) isa Vector{Matrix{Float64}}
33+
@inferred Vector{Float64} sol_ODE_2D(out_VF, tt; idxs = 3)
34+
@inferred Vector{Vector{Float64}} sol_ODE_2D(out_VVF_1, tt; idxs = 3)
35+
@inferred Vector{Vector{Float64}} sol_ODE_2D(out_VVF_1, tt; idxs = 3:3)
36+
@inferred Vector{Vector{Float64}} sol_ODE_2D(out_VVF_2, tt; idxs = 2:3)
37+
@inferred Vector{Matrix{Float64}} sol_ODE_2D(out_VMF, tt)
3838
@test sol_ODE_2D_interp.u out_VMF
3939
end
4040
end

0 commit comments

Comments
 (0)