Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ PyCall = "1.96"
PythonCall = "0.9.15"
RCall = "0.14.0"
RecipesBase = "1.3.4"
RecursiveArrayTools = "3.26.0"
RecursiveArrayTools = "3.27.2"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5.12"
SciMLOperators = "0.3.7"
Expand Down
14 changes: 0 additions & 14 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,20 +226,6 @@ end
end
end

Base.@propagate_inbounds function Base.getindex(
x::AbstractEnsembleSolution, s::Integer, i::Integer)
return x.u[s].u[i]
end

Base.@propagate_inbounds function Base.getindex(
x::AbstractEnsembleSolution, s::Integer, i2::Integer, i3::Integer, idxs::Integer...)
return x.u[s][i2, i3, idxs...]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return [xi[s] for xi in x.u]
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
[s(args...; kwargs...) for s in sol]
end
Expand Down
18 changes: 14 additions & 4 deletions test/downstream/ensemble_diffeq.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
using OrdinaryDiffEq
using OrdinaryDiffEq, Test

prob = ODEProblem((u, p, t) -> 1.01u, 0.5, (0.0, 1.0))
A = [1 2
3 4]
prob = ODEProblem((u, p, t) -> A*u, ones(2,2), (0.0, 1.0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
prob = ODEProblem((u, p, t) -> A*u, ones(2,2), (0.0, 1.0))
prob = ODEProblem((u, p, t) -> A * u, ones(2, 2), (0.0, 1.0))

function prob_func(prob, i, repeat)
remake(prob, u0 = rand() * prob.u0)
remake(prob, u0 = i * prob.u0)
end
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10)
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10, saveat=0.01)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10, saveat=0.01)
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10, saveat = 0.01)

@test sim isa EnsembleSolution
@test size(sim[1,:,:,:]) == (2,101,10)
@test size(sim[:,1,:,:]) == (2,101,10)
@test size(sim[:,:,1,:]) == (2,2,10)
@test size(sim[:,:,:,1]) == (2,2,101)
@test Array(sim)[1,:,:,:] == sim[1,:,:,:]
@test Array(sim)[:,1,:,:] == sim[:,1,:,:]
@test Array(sim)[:,:,1,:] == sim[:,:,1,:]
@test Array(sim)[:,:,:,1] == sim[:,:,:,1]
Comment on lines +12 to +19
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@test size(sim[1,:,:,:]) == (2,101,10)
@test size(sim[:,1,:,:]) == (2,101,10)
@test size(sim[:,:,1,:]) == (2,2,10)
@test size(sim[:,:,:,1]) == (2,2,101)
@test Array(sim)[1,:,:,:] == sim[1,:,:,:]
@test Array(sim)[:,1,:,:] == sim[:,1,:,:]
@test Array(sim)[:,:,1,:] == sim[:,:,1,:]
@test Array(sim)[:,:,:,1] == sim[:,:,:,1]
@test size(sim[1, :, :, :]) == (2, 101, 10)
@test size(sim[:, 1, :, :]) == (2, 101, 10)
@test size(sim[:, :, 1, :]) == (2, 2, 10)
@test size(sim[:, :, :, 1]) == (2, 2, 101)
@test Array(sim)[1, :, :, :] == sim[1, :, :, :]
@test Array(sim)[:, 1, :, :] == sim[:, 1, :, :]
@test Array(sim)[:, :, 1, :] == sim[:, :, 1, :]
@test Array(sim)[:, :, :, 1] == sim[:, :, :, 1]

Loading