Skip to content

Commit 266c716

Browse files
committed
add VoA and StructArray compatibility
1 parent 1d144be commit 266c716

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

lib/OrdinaryDiffEqLowStorageRK/test/ode_low_storage_rk_tests.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,3 +1605,26 @@ end
16051605
save_start = false, alias = ODEAliasSpecifier(alias_u0 = true))
16061606
@test sol_old[end] sol_new[end]
16071607
end
1608+
1609+
@testset "VectorOfArray/StructArray compatibility" begin
1610+
using RecursiveArrayTools, StaticArrays, StructArrays
1611+
1612+
function rhs!(du_voa, u_voa, p, t)
1613+
du = parent(du_voa)
1614+
u = parent(u_voa)
1615+
du .= u
1616+
end
1617+
1618+
# StructArray storage
1619+
u = StructArray{SVector{1, Float64}}(ntuple(_ -> [1.0, 2.0], 1))
1620+
ode = ODEProblem(rhs!, VectorOfArray(u), (0, 0.7))
1621+
sol_SA = solve(ode, RDPK3SpFSAL35())
1622+
1623+
# Vector{<:SVector} storage
1624+
u = SVector{1, Float64}.([1.0, 2.0])
1625+
ode = ODEProblem(rhs!, VectorOfArray(u), (0, 0.7))
1626+
sol_SV = solve(ode, RDPK3SpFSAL35())
1627+
1628+
@test sol.SA sol_SV
1629+
@test sol_SV.stats.naccept == sol_SA.stats.naccept
1630+
end

lib/OrdinaryDiffEqSSPRK/test/ode_ssprk_tests.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,3 +491,26 @@ end
491491
sol = solve(test_problem_ssp_long, alg, dt = OrdinaryDiffEqSSPRK.ssp_coefficient(alg),
492492
dense = false)
493493
@test all(sol.u .>= 0)
494+
495+
@testset "VectorOfArray/StructArray compatibility" begin
496+
using RecursiveArrayTools, StaticArrays, StructArrays
497+
498+
function rhs!(du_voa, u_voa, p, t)
499+
du = parent(du_voa)
500+
u = parent(u_voa)
501+
du .= u
502+
end
503+
504+
# StructArray storage
505+
u = StructArray{SVector{1, Float64}}(ntuple(_ -> [1.0, 2.0], 1))
506+
ode = ODEProblem(rhs!, VectorOfArray(u), (0, 0.7))
507+
sol_SA = solve(ode, SSPRK43())
508+
509+
# Vector{<:SVector} storage
510+
u = SVector{1, Float64}.([1.0, 2.0])
511+
ode = ODEProblem(rhs!, VectorOfArray(u), (0, 0.7))
512+
sol_SV = solve(ode, SSPRK43())
513+
514+
@test sol.SA sol_SV
515+
@test sol_SV.stats.naccept == sol_SA.stats.naccept
516+
end

0 commit comments

Comments
 (0)