From d3fe8c01e1827e63c8b2eefe48d1979a09fdd324 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 30 Oct 2024 07:34:08 -0100 Subject: [PATCH 1/2] Simplify ensemble indexing Should fix the segfaults on SciMLSensitivity, and there are tests to ensure it does not regress --- Project.toml | 2 +- src/ensemble/ensemble_solutions.jl | 14 -------------- test/downstream/ensemble_diffeq.jl | 18 ++++++++++++++---- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index c7057a507..a0727ed97 100644 --- a/Project.toml +++ b/Project.toml @@ -78,7 +78,7 @@ PyCall = "1.96" PythonCall = "0.9.15" RCall = "0.14.0" RecipesBase = "1.3.4" -RecursiveArrayTools = "3.26.0" +RecursiveArrayTools = "3.27.2" Reexport = "1" RuntimeGeneratedFunctions = "0.5.12" SciMLOperators = "0.3.7" diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index 73969102d..9c9b60502 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -226,20 +226,6 @@ end end end -Base.@propagate_inbounds function Base.getindex( - x::AbstractEnsembleSolution, s::Integer, i::Integer) - return x.u[s].u[i] -end - -Base.@propagate_inbounds function Base.getindex( - x::AbstractEnsembleSolution, s::Integer, i2::Integer, i3::Integer, idxs::Integer...) - return x.u[s][i2, i3, idxs...] -end - -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon) - return [xi[s] for xi in x.u] -end - function (sol::AbstractEnsembleSolution)(args...; kwargs...) [s(args...; kwargs...) for s in sol] end diff --git a/test/downstream/ensemble_diffeq.jl b/test/downstream/ensemble_diffeq.jl index ea792a8ee..2985cc72b 100644 --- a/test/downstream/ensemble_diffeq.jl +++ b/test/downstream/ensemble_diffeq.jl @@ -1,9 +1,19 @@ -using OrdinaryDiffEq +using OrdinaryDiffEq, Test -prob = ODEProblem((u, p, t) -> 1.01u, 0.5, (0.0, 1.0)) +A = [1 2 + 3 4] +prob = ODEProblem((u, p, t) -> A*u, ones(2,2), (0.0, 1.0)) function prob_func(prob, i, repeat) - remake(prob, u0 = rand() * prob.u0) + remake(prob, u0 = i * prob.u0) end ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) -sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10) +sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10, saveat=0.01) @test sim isa EnsembleSolution +@test size(sim[1,:,:,:]) == (2,101,10) +@test size(sim[:,1,:,:]) == (2,101,10) +@test size(sim[:,:,1,:]) == (2,2,10) +@test size(sim[:,:,:,1]) == (2,2,101) +@test Array(sim)[1,:,:,:] == sim[1,:,:,:] +@test Array(sim)[:,1,:,:] == sim[:,1,:,:] +@test Array(sim)[:,:,1,:] == sim[:,:,1,:] +@test Array(sim)[:,:,:,1] == sim[:,:,:,1] \ No newline at end of file From 5e315dab48ef2d567399b9086c7b9a643780526a Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 30 Oct 2024 09:56:32 -0100 Subject: [PATCH 2/2] Fix the tests --- test/downstream/ensemble_multi_prob.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl index 6277c3fb8..92f20065b 100644 --- a/test/downstream/ensemble_multi_prob.jl +++ b/test/downstream/ensemble_multi_prob.jl @@ -17,10 +17,9 @@ prob3 = ODEProblem(sys3, [3.0, 3.0], (0.0, 1.0)) ensemble_prob = EnsembleProblem([prob1, prob2, prob3]) sol = solve(ensemble_prob, Tsit5(), EnsembleThreads()) for i in 1:3 - @test sol[x, :][i] == sol.u[i][x] - @test sol[y, :][i] == sol.u[i][y] + @test sol[1,:,i] == sol.u[i][x] + @test sol[2,:,i] == sol.u[i][y] end # Ensemble is a recursive array -@test only.(sol(0.0, idxs = [x])) == sol[1, 1, :] == first.(sol[x, :]) -# TODO: fix the interpolation -@test only.(sol(1.0, idxs = [x])) ≈ last.(sol[x, :]) +@test only.(sol(0.0, idxs = [x])) == sol[1, 1, :] +@test only.(sol(1.0, idxs = [x])) ≈ [sol[i][1, end] for i in 1:3] \ No newline at end of file