Skip to content

Commit c321a87

Browse files
Merge pull request #2633 from jlchan/jc/add_VoA_StructArray_tests
Check for consistency between different `VectorOfArray` parent array types
2 parents acb025e + 5191d89 commit c321a87

File tree

5 files changed

+65
-3
lines changed

5 files changed

+65
-3
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ SparseDiffTools = "2"
142142
Static = "0.8, 1"
143143
StaticArrayInterface = "1.2"
144144
StaticArrays = "1.0"
145+
StructArrays = "0.6"
145146
TruncatedStacktraces = "1.2"
146147
julia = "1.10"
147148

@@ -167,9 +168,10 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
167168
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
168169
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
169170
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
171+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
170172
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
171173
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
172174
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
173175

174176
[targets]
175-
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve", "RecursiveFactorization"]
177+
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "StructArrays", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve", "RecursiveFactorization"]

lib/OrdinaryDiffEqLowStorageRK/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Reexport = "1.2.2"
3434
SafeTestsets = "0.1.0"
3535
Static = "1.1.1"
3636
StaticArrays = "1.9.7"
37+
StructArrays = "0.6"
3738
Test = "<0.0.1, 1"
3839
julia = "1.10"
3940

@@ -42,7 +43,8 @@ DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
4243
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
4344
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4445
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
46+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
4547
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4648

4749
[targets]
48-
test = ["DiffEqDevTools", "Random", "SafeTestsets", "Test", "ODEProblemLibrary"]
50+
test = ["DiffEqDevTools", "Random", "SafeTestsets", "StructArrays", "Test", "ODEProblemLibrary"]

lib/OrdinaryDiffEqLowStorageRK/test/ode_low_storage_rk_tests.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,3 +1605,36 @@ end
16051605
save_start = false, alias = ODEAliasSpecifier(alias_u0 = true))
16061606
@test sol_old[end] sol_new[end]
16071607
end
1608+
1609+
@testset "VectorOfArray/StructArray compatibility" begin
1610+
using RecursiveArrayTools, StaticArrays, StructArrays
1611+
1612+
function rhs!(du_voa, u_voa, p, t)
1613+
du = parent(du_voa)
1614+
u = parent(u_voa)
1615+
du .= u
1616+
end
1617+
1618+
# StructArray storage
1619+
u = StructArray{SVector{1, Float64}}(ntuple(_ -> [1.0, 2.0], 1))
1620+
ode = ODEProblem(rhs!, VectorOfArray(u), (0, 0.7))
1621+
sol_SA = solve(ode, RDPK3SpFSAL35())
1622+
1623+
# Vector{<:SVector} storage
1624+
u = SVector{1, Float64}.([1.0, 2.0])
1625+
ode = ODEProblem(rhs!, VectorOfArray(u), (0, 0.7))
1626+
sol_SV = solve(ode, RDPK3SpFSAL35())
1627+
1628+
@test sol_SA sol_SV
1629+
@test sol_SV.stats.naccept == sol_SA.stats.naccept
1630+
1631+
# Plain vector
1632+
u = [1.0, 2.0]
1633+
ode = ODEProblem(rhs!, u, (0, 0.7))
1634+
sol = solve(ode, RDPK3SpFSAL35())
1635+
@test sol.stats.naccept == sol_SA.stats.naccept
1636+
@test sol.t sol_SA.t
1637+
for i in eachindex(sol_SA.u), j in eachindex(u)
1638+
@test sol.u[i][j] sol_SA.u[i][j][1]
1639+
end
1640+
end

lib/OrdinaryDiffEqSSPRK/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Reexport = "1.2.2"
3333
SafeTestsets = "0.1.0"
3434
Static = "1.1.1"
3535
StaticArrays = "1.9.7"
36+
StructArrays = "0.6"
3637
Test = "<0.0.1, 1"
3738
julia = "1.10"
3839

@@ -42,7 +43,8 @@ ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
4243
OrdinaryDiffEqLowStorageRK = "b0944070-b475-4768-8dec-fb6eb410534d"
4344
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4445
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
46+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
4547
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4648

4749
[targets]
48-
test = ["DiffEqDevTools", "Random", "SafeTestsets", "Test", "ODEProblemLibrary", "OrdinaryDiffEqLowStorageRK"]
50+
test = ["DiffEqDevTools", "Random", "SafeTestsets", "StructArrays", "Test", "ODEProblemLibrary", "OrdinaryDiffEqLowStorageRK"]

lib/OrdinaryDiffEqSSPRK/test/ode_ssprk_tests.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,3 +491,26 @@ end
491491
sol = solve(test_problem_ssp_long, alg, dt = OrdinaryDiffEqSSPRK.ssp_coefficient(alg),
492492
dense = false)
493493
@test all(sol.u .>= 0)
494+
495+
@testset "VectorOfArray/StructArray compatibility" begin
496+
using RecursiveArrayTools, StaticArrays, StructArrays
497+
498+
function rhs!(du_voa, u_voa, p, t)
499+
du = parent(du_voa)
500+
u = parent(u_voa)
501+
du .= u
502+
end
503+
504+
# StructArray storage
505+
u = StructArray{SVector{1, Float64}}(ntuple(_ -> [1.0, 2.0], 1))
506+
ode = ODEProblem(rhs!, VectorOfArray(u), (0, 0.7))
507+
sol_SA = solve(ode, SSPRK43())
508+
509+
# Vector{<:SVector} storage
510+
u = SVector{1, Float64}.([1.0, 2.0])
511+
ode = ODEProblem(rhs!, VectorOfArray(u), (0, 0.7))
512+
sol_SV = solve(ode, SSPRK43())
513+
514+
@test sol_SA sol_SV
515+
@test sol_SV.stats.naccept == sol_SA.stats.naccept
516+
end

0 commit comments

Comments
 (0)