Skip to content

Commit dd9f8cf

Browse files
Merge pull request #2201 from SciML/dprkn6_interp
Fix and test DPRKN interpolation with idxs
2 parents 1f0e158 + 2e7cb8c commit dd9f8cf

File tree

3 files changed

+78
-25
lines changed

3 files changed

+78
-25
lines changed

src/dense/interpolants.jl

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2947,6 +2947,28 @@ end
29472947
b4Θ * k4[idxs] + b5Θ * k5[idxs] + b6Θ * k6[idxs])))
29482948
end
29492949

2950+
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
2951+
cache::Union{DPRKN6ConstantCache, DPRKN6Cache}, idxs::Number,
2952+
T::Type{Val{0}}, differential_vars::Nothing)
2953+
@dprkn6pre0
2954+
halfsize = length(y₀) ÷ 2
2955+
if idxs <= halfsize
2956+
duprev[idxs] +
2957+
dt * Θ *
2958+
(bp1Θ * k1[idxs] + bp3Θ * k3[idxs] +
2959+
bp4Θ * k4[idxs] + bp5Θ * k5[idxs] + bp6Θ * k6[idxs])
2960+
else
2961+
idxs = idxs - halfsize
2962+
uprev[idxs] +
2963+
dt * Θ *
2964+
(duprev[idxs] +
2965+
dt * Θ *
2966+
(b1Θ * k1[idxs] +
2967+
b3Θ * k3[idxs] +
2968+
b4Θ * k4[idxs] + b5Θ * k5[idxs] + b6Θ * k6[idxs]))
2969+
end
2970+
end
2971+
29502972
@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
29512973
cache::Union{DPRKN6ConstantCache, DPRKN6Cache},
29522974
idxs::Nothing, T::Type{Val{0}}, differential_vars::Nothing)
@@ -2976,25 +2998,28 @@ end
29762998
cache::Union{DPRKN6ConstantCache, DPRKN6Cache}, idxs,
29772999
T::Type{Val{0}}, differential_vars::Nothing)
29783000
@dprkn6pre0
2979-
@inbounds @.. broadcast=false out.x[2]=uprev[idxs] +
2980-
dt * Θ *
2981-
(duprev[idxs] +
2982-
dt * Θ *
2983-
(b1Θ * k1[idxs] +
2984-
b3Θ * k3[idxs] +
2985-
b4Θ * k4[idxs] + b5Θ * k5[idxs] +
2986-
b6Θ * k6[idxs]))
2987-
@inbounds @.. broadcast=false out.x[1]=duprev[idxs] +
2988-
dt * Θ *
2989-
(bp1Θ * k1[idxs] + bp3Θ * k3[idxs] +
2990-
bp4Θ * k4[idxs] + bp5Θ * k5[idxs] +
2991-
bp6Θ * k6[idxs])
2992-
#for (j,i) in enumerate(idxs)
2993-
# out.x[2][j] = uprev[i] + dt*Θ*(duprev[i] + dt*Θ*(b1Θ*k1[i] +
2994-
# b3Θ*k3[i] +
2995-
# b4Θ*k4[i] + b5Θ*k5[i] + b6Θ*k6[i]))
2996-
# out.x[1][j] = duprev[i] + dt*Θ*(bp1Θ*k1[i] + bp3Θ*k3[i] +
2997-
# bp4Θ*k4[i] + bp5Θ*k5[i] + bp6Θ*k6[i])
2998-
#end
3001+
halfsize = length(y₀) ÷ 2
3002+
isfirsthalf = idxs .<= halfsize
3003+
secondhalf = idxs .> halfsize
3004+
firstidxs = idxs[isfirsthalf]
3005+
secondidxs_shifted = idxs[secondhalf]
3006+
secondidxs = secondidxs_shifted .- halfsize
3007+
3008+
@views @.. broadcast=false out[secondhalf]=uprev[secondidxs] +
3009+
dt * Θ *
3010+
(duprev[secondidxs] +
3011+
dt * Θ *
3012+
(b1Θ * k1[secondidxs] +
3013+
b3Θ * k3[secondidxs] +
3014+
b4Θ * k4[secondidxs] +
3015+
b5Θ * k5[secondidxs] +
3016+
b6Θ * k6[secondidxs]))
3017+
@views @.. broadcast=false out[isfirsthalf]=duprev[firstidxs] +
3018+
dt * Θ *
3019+
(bp1Θ * k1[firstidxs] +
3020+
bp3Θ * k3[firstidxs] +
3021+
bp4Θ * k4[firstidxs] +
3022+
bp5Θ * k5[firstidxs] +
3023+
bp6Θ * k6[firstidxs])
29993024
out
30003025
end

src/nlsolve/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,8 @@ function build_nlsolver(
308308
end
309309
prob = NonlinearProblem(NonlinearFunction(nlf), copy(ztmp), nlp_params)
310310
cache = init(prob, nlalg.alg)
311-
nlcache = NonlinearSolveCache(nothing, tstep, nothing, nothing, invγdt, prob, cache)
311+
nlcache = NonlinearSolveCache(
312+
nothing, tstep, nothing, nothing, invγdt, prob, cache)
312313
else
313314
nlcache = NLNewtonConstantCache(tstep, J, W, true, true, true, tType(dt), uf,
314315
invγdt, tType(nlalg.new_W_dt_cutoff), t)

test/interface/interpolation_output_types.jl

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ res1 = solve(prob, Vern8(), dt = 1 / 10, saveat = 1 / 10)
1717
res3 = solve(prob, CalvoSanz4(), dt = 1 / 10, saveat = 1 / 10)
1818

1919
sol = solve(prob, CalvoSanz4(), dt = 1 / 10)
20-
@test sol(0.32) isa OrdinaryDiffEq.ArrayPartition
21-
@test sol(0.32, Val{1}) isa OrdinaryDiffEq.ArrayPartition
22-
@test sol(0.32, Val{2}) isa OrdinaryDiffEq.ArrayPartition
23-
@test sol(0.32, Val{3}) isa OrdinaryDiffEq.ArrayPartition
20+
@test sol(0.32) isa RecursiveArrayTools.ArrayPartition
21+
@test sol(0.32, Val{1}) isa RecursiveArrayTools.ArrayPartition
22+
@test sol(0.32, Val{2}) isa RecursiveArrayTools.ArrayPartition
23+
@test sol(0.32, Val{3}) isa RecursiveArrayTools.ArrayPartition
2424

2525
function f(du, u, p, t)
2626
du .= u
@@ -40,3 +40,30 @@ sol(0:0.1:100; idxs = [1, 2])
4040
@test sol(0:0.1:100) isa DiffEqArray
4141
@test length(sol(0:0.1:100)) == length(0:0.1:100)
4242
@test length(sol(0:0.1:100).u[1]) == 3
43+
44+
## Test DPRKN Interpolation
45+
46+
#Parameters
47+
ω = 1
48+
49+
#Initial Conditions
50+
x₀ = [0.0]
51+
dx₀ =/ 2]
52+
tspan = (0.0, 2π)
53+
54+
ϕ = atan((dx₀[1] / ω) / x₀[1])
55+
A = (x₀[1]^2 + dx₀[1]^2)
56+
57+
function harmonicoscillator(ddu, du, u, ω, t)
58+
ddu .= -ω^2 * u
59+
end
60+
61+
prob = SecondOrderODEProblem(harmonicoscillator, dx₀, x₀, tspan, ω)
62+
sol = solve(prob, DPRKN6())
63+
@test sol(0.5) isa RecursiveArrayTools.ArrayPartition
64+
@test sol(0.5; idxs = 1) isa Number
65+
@test sol(0.5; idxs = [1]) isa Vector
66+
@test sol(0.5; idxs = [1, 2]) isa Vector
67+
@test Vector(sol(0.5)) == sol(0.5; idxs = [1, 2])
68+
@test reverse(Vector(sol(0.5))) == sol(0.5; idxs = [2, 1])
69+
@test Vector(sol(0.5)) == [sol(0.5; idxs = 1); sol(0.5; idxs = 2)]

0 commit comments

Comments
 (0)