Skip to content

Commit 1180779

Browse files
Merge pull request #839 from SciML/ensemble_indexing
Simplify ensemble indexing
2 parents c341819 + 5e315da commit 1180779

File tree

4 files changed

+19
-24
lines changed

4 files changed

+19
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ PyCall = "1.96"
7878
PythonCall = "0.9.15"
7979
RCall = "0.14.0"
8080
RecipesBase = "1.3.4"
81-
RecursiveArrayTools = "3.26.0"
81+
RecursiveArrayTools = "3.27.2"
8282
Reexport = "1"
8383
RuntimeGeneratedFunctions = "0.5.12"
8484
SciMLOperators = "0.3.7"

src/ensemble/ensemble_solutions.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -226,20 +226,6 @@ end
226226
end
227227
end
228228

229-
Base.@propagate_inbounds function Base.getindex(
230-
x::AbstractEnsembleSolution, s::Integer, i::Integer)
231-
return x.u[s].u[i]
232-
end
233-
234-
Base.@propagate_inbounds function Base.getindex(
235-
x::AbstractEnsembleSolution, s::Integer, i2::Integer, i3::Integer, idxs::Integer...)
236-
return x.u[s][i2, i3, idxs...]
237-
end
238-
239-
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
240-
return [xi[s] for xi in x.u]
241-
end
242-
243229
function (sol::AbstractEnsembleSolution)(args...; kwargs...)
244230
[s(args...; kwargs...) for s in sol]
245231
end

test/downstream/ensemble_diffeq.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
1-
using OrdinaryDiffEq
1+
using OrdinaryDiffEq, Test
22

3-
prob = ODEProblem((u, p, t) -> 1.01u, 0.5, (0.0, 1.0))
3+
A = [1 2
4+
3 4]
5+
prob = ODEProblem((u, p, t) -> A*u, ones(2,2), (0.0, 1.0))
46
function prob_func(prob, i, repeat)
5-
remake(prob, u0 = rand() * prob.u0)
7+
remake(prob, u0 = i * prob.u0)
68
end
79
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
8-
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10)
10+
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10, saveat=0.01)
911
@test sim isa EnsembleSolution
12+
@test size(sim[1,:,:,:]) == (2,101,10)
13+
@test size(sim[:,1,:,:]) == (2,101,10)
14+
@test size(sim[:,:,1,:]) == (2,2,10)
15+
@test size(sim[:,:,:,1]) == (2,2,101)
16+
@test Array(sim)[1,:,:,:] == sim[1,:,:,:]
17+
@test Array(sim)[:,1,:,:] == sim[:,1,:,:]
18+
@test Array(sim)[:,:,1,:] == sim[:,:,1,:]
19+
@test Array(sim)[:,:,:,1] == sim[:,:,:,1]

test/downstream/ensemble_multi_prob.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@ prob3 = ODEProblem(sys3, [3.0, 3.0], (0.0, 1.0))
1717
ensemble_prob = EnsembleProblem([prob1, prob2, prob3])
1818
sol = solve(ensemble_prob, Tsit5(), EnsembleThreads())
1919
for i in 1:3
20-
@test sol[x, :][i] == sol.u[i][x]
21-
@test sol[y, :][i] == sol.u[i][y]
20+
@test sol[1,:,i] == sol.u[i][x]
21+
@test sol[2,:,i] == sol.u[i][y]
2222
end
2323
# Ensemble is a recursive array
24-
@test only.(sol(0.0, idxs = [x])) == sol[1, 1, :] == first.(sol[x, :])
25-
# TODO: fix the interpolation
26-
@test only.(sol(1.0, idxs = [x])) last.(sol[x, :])
24+
@test only.(sol(0.0, idxs = [x])) == sol[1, 1, :]
25+
@test only.(sol(1.0, idxs = [x])) [sol[i][1, end] for i in 1:3]

0 commit comments

Comments
 (0)