Skip to content

Commit 0b533b0

Browse files
Fix and test DPRKN interpolation with idxs
1 parent 1f0e158 commit 0b533b0

File tree

3 files changed

+71
-21
lines changed

3 files changed

+71
-21
lines changed

src/dense/generic_dense.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,7 @@ function ode_interpolant(Θ, dt, y₀, y₁, k, cache::OrdinaryDiffEqMutableCach
591591
T::Type{Val{TI}}, differential_vars) where {TI}
592592
if idxs isa Number || y₀ isa Union{Number, SArray}
593593
# typeof(y₀) can be these if saveidxs gives a single value
594+
@show Main.@which _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars)
594595
_ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars)
595596
elseif idxs isa Nothing
596597
if y₁ isa Array{<:Number}

src/dense/interpolants.jl

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2947,6 +2947,28 @@ end
29472947
b4Θ * k4[idxs] + b5Θ * k5[idxs] + b6Θ * k6[idxs])))
29482948
end
29492949

2950+
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
2951+
cache::Union{DPRKN6ConstantCache, DPRKN6Cache}, idxs::Number,
2952+
T::Type{Val{0}}, differential_vars::Nothing)
2953+
@dprkn6pre0
2954+
halfsize = length(y₀) ÷ 2
2955+
if idxs <= halfsize
2956+
duprev[idxs] +
2957+
dt * Θ *
2958+
(bp1Θ * k1[idxs] + bp3Θ * k3[idxs] +
2959+
bp4Θ * k4[idxs] + bp5Θ * k5[idxs] + bp6Θ * k6[idxs])
2960+
else
2961+
idxs = idxs - halfsize
2962+
uprev[idxs] +
2963+
dt * Θ *
2964+
(duprev[idxs] +
2965+
dt * Θ *
2966+
(b1Θ * k1[idxs] +
2967+
b3Θ * k3[idxs] +
2968+
b4Θ * k4[idxs] + b5Θ * k5[idxs] + b6Θ * k6[idxs]))
2969+
end
2970+
end
2971+
29502972
@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
29512973
cache::Union{DPRKN6ConstantCache, DPRKN6Cache},
29522974
idxs::Nothing, T::Type{Val{0}}, differential_vars::Nothing)
@@ -2976,25 +2998,25 @@ end
29762998
cache::Union{DPRKN6ConstantCache, DPRKN6Cache}, idxs,
29772999
T::Type{Val{0}}, differential_vars::Nothing)
29783000
@dprkn6pre0
2979-
@inbounds @.. broadcast=false out.x[2]=uprev[idxs] +
3001+
halfsize = length(y₀) ÷ 2
3002+
isfirsthalf = idxs .<= halfsize
3003+
secondhalf = idxs .> halfsize
3004+
firstidxs = idxs[isfirsthalf]
3005+
secondidxs_shifted = idxs[secondhalf]
3006+
secondidxs = secondidxs_shifted .- halfsize
3007+
3008+
@views @.. broadcast=false out[secondhalf]=uprev[secondidxs] +
29803009
dt * Θ *
2981-
(duprev[idxs] +
3010+
(duprev[secondidxs] +
29823011
dt * Θ *
2983-
(b1Θ * k1[idxs] +
2984-
b3Θ * k3[idxs] +
2985-
b4Θ * k4[idxs] + b5Θ * k5[idxs] +
2986-
b6Θ * k6[idxs]))
2987-
@inbounds @.. broadcast=false out.x[1]=duprev[idxs] +
3012+
(b1Θ * k1[secondidxs] +
3013+
b3Θ * k3[secondidxs] +
3014+
b4Θ * k4[secondidxs] + b5Θ * k5[secondidxs] +
3015+
b6Θ * k6[secondidxs]))
3016+
@views @.. broadcast=false out[isfirsthalf]=duprev[firstidxs] +
29883017
dt * Θ *
2989-
(bp1Θ * k1[idxs] + bp3Θ * k3[idxs] +
2990-
bp4Θ * k4[idxs] + bp5Θ * k5[idxs] +
2991-
bp6Θ * k6[idxs])
2992-
#for (j,i) in enumerate(idxs)
2993-
# out.x[2][j] = uprev[i] + dt*Θ*(duprev[i] + dt*Θ*(b1Θ*k1[i] +
2994-
# b3Θ*k3[i] +
2995-
# b4Θ*k4[i] + b5Θ*k5[i] + b6Θ*k6[i]))
2996-
# out.x[1][j] = duprev[i] + dt*Θ*(bp1Θ*k1[i] + bp3Θ*k3[i] +
2997-
# bp4Θ*k4[i] + bp5Θ*k5[i] + bp6Θ*k6[i])
2998-
#end
3018+
(bp1Θ * k1[firstidxs] + bp3Θ * k3[firstidxs] +
3019+
bp4Θ * k4[firstidxs] + bp5Θ * k5[firstidxs] +
3020+
bp6Θ * k6[firstidxs])
29993021
out
30003022
end

test/interface/interpolation_output_types.jl

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ res1 = solve(prob, Vern8(), dt = 1 / 10, saveat = 1 / 10)
1717
res3 = solve(prob, CalvoSanz4(), dt = 1 / 10, saveat = 1 / 10)
1818

1919
sol = solve(prob, CalvoSanz4(), dt = 1 / 10)
20-
@test sol(0.32) isa OrdinaryDiffEq.ArrayPartition
21-
@test sol(0.32, Val{1}) isa OrdinaryDiffEq.ArrayPartition
22-
@test sol(0.32, Val{2}) isa OrdinaryDiffEq.ArrayPartition
23-
@test sol(0.32, Val{3}) isa OrdinaryDiffEq.ArrayPartition
20+
@test sol(0.32) isa RecursiveArrayTools.ArrayPartition
21+
@test sol(0.32, Val{1}) isa RecursiveArrayTools.ArrayPartition
22+
@test sol(0.32, Val{2}) isa RecursiveArrayTools.ArrayPartition
23+
@test sol(0.32, Val{3}) isa RecursiveArrayTools.ArrayPartition
2424

2525
function f(du, u, p, t)
2626
du .= u
@@ -40,3 +40,30 @@ sol(0:0.1:100; idxs = [1, 2])
4040
@test sol(0:0.1:100) isa DiffEqArray
4141
@test length(sol(0:0.1:100)) == length(0:0.1:100)
4242
@test length(sol(0:0.1:100).u[1]) == 3
43+
44+
## Test DPRKN Interpolation
45+
46+
#Parameters
47+
ω = 1
48+
49+
#Initial Conditions
50+
x₀ = [0.0]
51+
dx₀ =/ 2]
52+
tspan = (0.0, 2π)
53+
54+
ϕ = atan((dx₀[1] / ω) / x₀[1])
55+
A = (x₀[1]^2 + dx₀[1]^2)
56+
57+
function harmonicoscillator(ddu, du, u, ω, t)
58+
ddu .= -ω^2 * u
59+
end
60+
61+
prob = SecondOrderODEProblem(harmonicoscillator, dx₀, x₀, tspan, ω)
62+
sol = solve(prob, DPRKN6())
63+
@test sol(0.5) isa RecursiveArrayTools.ArrayPartition
64+
@test sol(0.5; idxs = 1) isa Number
65+
@test sol(0.5; idxs = [1]) isa Vector
66+
@test sol(0.5; idxs = [1,2]) isa Vector
67+
@test Vector(sol(0.5)) == sol(0.5; idxs = [1,2])
68+
@test reverse(Vector(sol(0.5))) == sol(0.5; idxs = [2,1])
69+
@test Vector(sol(0.5)) == [sol(0.5; idxs = 1); sol(0.5; idxs = 2)]

0 commit comments

Comments
 (0)