Skip to content

Commit 1bbecb3

Browse files
Merge pull request #501 from SciML/ensemble
fix thread-safety of splitthreads
2 parents 72ef5e4 + abff042 commit 1bbecb3

File tree

2 files changed

+41
-9
lines changed

2 files changed

+41
-9
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,8 @@ end
208208
function thread_monte(prob,I,alg,procid,kwargs...)
209209
batch_data = Vector{Any}(undef,length(I))
210210
let
211-
j = 0
212-
Threads.@threads for i in I
213-
j += 1
211+
Threads.@threads for j in 1:length(I)
212+
i = I[j]
214213
iter = 1
215214
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
216215
new_prob = prob.prob_func(_prob,i,iter)
Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,47 @@
11
using Distributed
22
addprocs(2)
33
println("There are $(nprocs()) processes")
4-
@everywhere using OrdinaryDiffEq
54

6-
@everywhere prob = ODEProblem((u,p,t)->1.01u,0.5,(0.0,1.0))
7-
@everywhere u0s = [rand()*prob.u0 for i in 1:2]
8-
@everywhere function prob_func(prob,i,repeat)
9-
println("Running trajectory $i")
10-
ODEProblem(prob.f,u0s[i],prob.tspan)
5+
@everywhere begin
6+
using Pkg
7+
Pkg.activate("downstream")
8+
Pkg.develop(PackageSpec(path=joinpath(pwd(), "..")))
9+
Pkg.instantiate()
10+
using OrdinaryDiffEq
11+
prob = ODEProblem((u,p,t)->1.01u,0.5,(0.0,1.0))
12+
u0s = [rand()*prob.u0 for i in 1:2]
13+
function prob_func(prob,i,repeat)
14+
println("Running trajectory $i")
15+
ODEProblem(prob.f,u0s[i],prob.tspan)
16+
end
1117
end
1218

1319
ensemble_prob = EnsembleProblem(prob, prob_func=prob_func)
1420
sim = solve(ensemble_prob,Tsit5(),EnsembleSplitThreads(),trajectories=2)
21+
22+
@everywhere function lorenz!(du,u,p,t)
23+
du[1] = 10.0*(u[2]-u[1])
24+
du[2] = u[1]*(28.0-u[3]) - u[2]
25+
du[3] = u[1]*u[2] - (8/3)*u[3]
26+
end
27+
28+
u0 = [1.0, 0.0, 0.0]
29+
tspan = (0.0, 100.0)
30+
p = [1, 2.0, 3]
31+
prob = ODEProblem(lorenz!, u0, tspan, p)
32+
33+
@everywhere function prob_func(prob,i,repeat)
34+
prob = remake(prob, tspan=(rand(), 100.0), p=rand(3))
35+
return prob
36+
end
37+
38+
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func, safetycopy=true)
39+
40+
println("Running EnsembleSerial()")
41+
@test length(solve(ensemble_prob, Tsit5(), EnsembleSerial(), trajectories=100)) == 100
42+
println("Running EnsembleThreads()")
43+
@test length(solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories=100)) == 100
44+
println("Running EnsembleDistributed()")
45+
@test length(solve(ensemble_prob, Tsit5(), EnsembleDistributed(), trajectories=100)) == 100
46+
println("Running EnsembleSplitThreads()")
47+
@test length(solve(ensemble_prob, Tsit5(), EnsembleSplitThreads(), trajectories=100)) == 100

0 commit comments

Comments
 (0)