Skip to content

Commit d3fe8c0

Browse files
Simplify ensemble indexing
Should fix the segfaults on SciMLSensitivity, and there are tests to ensure it does not regress
1 parent 4be9585 commit d3fe8c0

File tree

3 files changed

+15
-19
lines changed

3 files changed

+15
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ PyCall = "1.96"
7878
PythonCall = "0.9.15"
7979
RCall = "0.14.0"
8080
RecipesBase = "1.3.4"
81-
RecursiveArrayTools = "3.26.0"
81+
RecursiveArrayTools = "3.27.2"
8282
Reexport = "1"
8383
RuntimeGeneratedFunctions = "0.5.12"
8484
SciMLOperators = "0.3.7"

src/ensemble/ensemble_solutions.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -226,20 +226,6 @@ end
226226
end
227227
end
228228

229-
Base.@propagate_inbounds function Base.getindex(
230-
x::AbstractEnsembleSolution, s::Integer, i::Integer)
231-
return x.u[s].u[i]
232-
end
233-
234-
Base.@propagate_inbounds function Base.getindex(
235-
x::AbstractEnsembleSolution, s::Integer, i2::Integer, i3::Integer, idxs::Integer...)
236-
return x.u[s][i2, i3, idxs...]
237-
end
238-
239-
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
240-
return [xi[s] for xi in x.u]
241-
end
242-
243229
function (sol::AbstractEnsembleSolution)(args...; kwargs...)
244230
[s(args...; kwargs...) for s in sol]
245231
end

test/downstream/ensemble_diffeq.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
1-
using OrdinaryDiffEq
1+
using OrdinaryDiffEq, Test
22

3-
prob = ODEProblem((u, p, t) -> 1.01u, 0.5, (0.0, 1.0))
3+
A = [1 2
4+
3 4]
5+
prob = ODEProblem((u, p, t) -> A*u, ones(2,2), (0.0, 1.0))
46
function prob_func(prob, i, repeat)
5-
remake(prob, u0 = rand() * prob.u0)
7+
remake(prob, u0 = i * prob.u0)
68
end
79
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
8-
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10)
10+
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10, saveat=0.01)
911
@test sim isa EnsembleSolution
12+
@test size(sim[1,:,:,:]) == (2,101,10)
13+
@test size(sim[:,1,:,:]) == (2,101,10)
14+
@test size(sim[:,:,1,:]) == (2,2,10)
15+
@test size(sim[:,:,:,1]) == (2,2,101)
16+
@test Array(sim)[1,:,:,:] == sim[1,:,:,:]
17+
@test Array(sim)[:,1,:,:] == sim[:,1,:,:]
18+
@test Array(sim)[:,:,1,:] == sim[:,:,1,:]
19+
@test Array(sim)[:,:,:,1] == sim[:,:,:,1]

0 commit comments

Comments
 (0)