Skip to content

Commit 28a6572

Browse files
fix thread-safety of splitthreads
1 parent 72ef5e4 commit 28a6572

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,8 @@ end
208208
function thread_monte(prob,I,alg,procid,kwargs...)
209209
batch_data = Vector{Any}(undef,length(I))
210210
let
211-
j = 0
212-
Threads.@threads for i in I
213-
j += 1
211+
Threads.@threads for j in 1:length(I)
212+
i = I[j]
214213
iter = 1
215214
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
216215
new_prob = prob.prob_func(_prob,i,iter)

test/downstream/distributed_ensemble.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,30 @@ end
1212

1313
ensemble_prob = EnsembleProblem(prob, prob_func=prob_func)
1414
sim = solve(ensemble_prob,Tsit5(),EnsembleSplitThreads(),trajectories=2)
15+
16+
@everywhere function lorenz!(du,u,p,t)
17+
du[1] = 10.0*(u[2]-u[1])
18+
du[2] = u[1]*(28.0-u[3]) - u[2]
19+
du[3] = u[1]*u[2] - (8/3)*u[3]
20+
end
21+
22+
u0 = [1.0, 0.0, 0.0]
23+
tspan = (0.0, 100.0)
24+
p = [1, 2.0, 3]
25+
prob = ODEProblem(lorenz!, u0, tspan, p)
26+
27+
@everywhere function prob_func(prob,i,repeat)
28+
prob = remake(prob, tspan=(rand(), 100.0), p=rand(3))
29+
return prob
30+
end
31+
32+
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func, safetycopy=true)
33+
34+
println("Running EnsembleSerial()")
35+
@test length(solve(ensemble_prob, Tsit5(), EnsembleSerial(), trajectories=100)) == 100
36+
println("Running EnsembleThreads()")
37+
@test length(solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories=100)) == 100
38+
println("Running EnsembleDistributed()")
39+
@test length(solve(ensemble_prob, Tsit5(), EnsembleDistributed(), trajectories=100)) == 100
40+
println("Running EnsembleSplitThreads()")
41+
@test length(solve(ensemble_prob, Tsit5(), EnsembleSplitThreads(), trajectories=100)) == 100

0 commit comments

Comments
 (0)