Skip to content

Commit 7a5b19a

Browse files
finalize splitthreads fixes and test distributed ensembles
1 parent 63872c6 commit 7a5b19a

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,10 @@ end
199199

200200
function thread_monte(prob,I,alg,procid,kwargs...)
201201
batch_data = Vector{Any}(undef,length(I))
202-
@show I
203202
let
203+
j = 0
204204
Threads.@threads for i in I
205+
j += 1
205206
iter = 1
206207
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
207208
rerun = true
@@ -225,7 +226,7 @@ function thread_monte(prob,I,alg,procid,kwargs...)
225226
end
226227
rerun = _x[2]
227228
end
228-
batch_data[i - start + 1] = _x[1]
229+
batch_data[j] = _x[1]
229230
end
230231
end
231232
batch_data
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using Distributed
2+
addprocs(2)
3+
println("There are $(nprocs()) processes")
4+
@everywhere using OrdinaryDiffEq
5+
6+
@everywhere prob = ODEProblem((u,p,t)->1.01u,0.5,(0.0,1.0))
7+
@everywhere u0s = [rand()*prob.u0 for i in 1:2]
8+
@everywhere function prob_func(prob,i,repeat)
9+
println("Running trajectory $i")
10+
ODEProblem(prob.f,u0s[i],prob.tspan)
11+
end
12+
13+
ensemble_prob = EnsembleProblem(prob, prob_func=prob_func)
14+
sim = solve(ensemble_prob,Tsit5(),EnsembleSplitThreads(),trajectories=2)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ if !is_APPVEYOR && GROUP == "Downstream"
4848
@time @safetestset "DEDataArray" begin include("downstream/data_array_regression_tests.jl") end
4949
@time @safetestset "Concrete_solve Tests" begin include("downstream/concrete_solve_tests.jl") end
5050
@time @safetestset "AD Tests" begin include("downstream/ad_tests.jl") end
51+
@time @testset "Distributed Ensemble Tests" begin include("downstream/distributed_ensemble.jl") end
5152
end
5253

5354
if !is_APPVEYOR && GROUP == "GPU"

0 commit comments

Comments
 (0)