Skip to content

Commit 4c7f04d

Browse files
Merge pull request #518 from SciML/monte
fix multithreaded thread safety and clean up
2 parents 6ebad1a + 0be6652 commit 4c7f04d

File tree

3 files changed

+43
-111
lines changed

3 files changed

+43
-111
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 16 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -168,63 +168,31 @@ end
168168

169169
function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kwargs...)
170170

171+
if length(II) == 1 || Threads.nthreads() == 1
172+
return solve_batch(prob,alg,EnsembleSerial(),II,pmap_batch_size;kwargs...)
173+
end
174+
171175
if typeof(prob.prob) <: AbstractJumpProblem && length(II) != 1
172176
probs = [deepcopy(prob.prob) for i in 1:Threads.nthreads()]
173177
else
174178
probs = prob.prob
175179
end
176180

177-
function multithreaded_batch(batch_idx)
178-
i = II[batch_idx]
179-
iter = 1
180-
_prob = if prob.safetycopy
181-
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
182-
else
183-
probs isa Vector ? probs[Threads.threadid()] : probs
184-
end
185-
new_prob = prob.prob_func(_prob,i,iter)
186-
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
187-
if !(typeof(x) <: Tuple)
188-
@warn("output_func should return (out,rerun). See docs for updated details")
189-
_x = (x,false)
190-
else
191-
_x = x
192-
end
193-
rerun = _x[2]
194-
195-
while rerun
196-
iter += 1
197-
_prob = if prob.safetycopy
198-
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
199-
else
200-
probs isa Vector ? probs[Threads.threadid()] : probs
201-
end
202-
new_prob = prob.prob_func(_prob,i,iter)
203-
x = prob.output_func(solve(new_prob,alg;alias_jumps=true,kwargs...),i)
204-
if !(typeof(x) <: Tuple)
205-
@warn("output_func should return (out,rerun). See docs for updated details")
206-
_x = (x,false)
207-
else
208-
_x = x
209-
end
210-
rerun = _x[2]
211-
end
212-
_x[1]
213-
end
214-
215181
#batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
216-
batch_data = Vector{Any}(undef,length(II))
182+
183+
batch_data = Vector{Any}(undef,Threads.nthreads())
184+
batch_size = length(II)÷Threads.nthreads()
217185

218186
let
219-
if length(II) == 1 || Threads.nthreads() == 1
220-
for batch_idx in axes(batch_data, 1)
221-
batch_data[batch_idx] = multithreaded_batch(batch_idx)
222-
end
223-
else
224-
Threads.@threads for batch_idx in axes(batch_data, 1)
225-
batch_data[batch_idx] = multithreaded_batch(batch_idx)
226-
end
187+
Threads.@threads for i in 1:Threads.nthreads()
188+
if i == Threads.nthreads()
189+
I_local = II[(batch_size*(i-1)+1):end]
190+
else
191+
I_local = II[(batch_size*(i-1)+1):(batch_size*i)]
192+
end
193+
batch_data[i] = solve_batch(prob,alg,EnsembleSerial(),I_local,pmap_batch_size;kwargs...)
227194
end
195+
batch_data = reduce(vcat,batch_data)
228196
end
229197
tighten_container_eltype(batch_data)
230198
end
@@ -240,71 +208,8 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs..
240208
else
241209
I_local = II[(batch_size*(i-1)+1):(batch_size*i)]
242210
end
243-
thread_monte(prob,I_local,alg,i;kwargs...)
211+
solve_batch(prob,alg,EnsembleThreads(),I_local,pmap_batch_size;kwargs...)
244212
end
245213
end
246214
reduce(vcat,batch_data)
247215
end
248-
249-
function thread_monte(prob,II,alg,procid;kwargs...)
250-
251-
if typeof(prob.prob) <: AbstractJumpProblem && length(II) != 1
252-
probs = [deepcopy(prob.prob) for i in 1:Threads.nthreads()]
253-
else
254-
probs = prob.prob
255-
end
256-
257-
function multithreaded_batch(j)
258-
i = II[j]
259-
iter = 1
260-
_prob = if prob.safetycopy
261-
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
262-
else
263-
probs isa Vector ? probs[Threads.threadid()] : probs
264-
end
265-
new_prob = prob.prob_func(_prob,i,iter)
266-
rerun = true
267-
x = prob.output_func(solve(new_prob,alg;alias_jumps=true,kwargs...),i)
268-
if !(typeof(x) <: Tuple)
269-
@warn("output_func should return (out,rerun). See docs for updated details")
270-
_x = (x,false)
271-
else
272-
_x = x
273-
end
274-
rerun = _x[2]
275-
while rerun
276-
iter += 1
277-
_prob = if prob.safetycopy
278-
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
279-
else
280-
probs isa Vector ? probs[Threads.threadid()] : probs
281-
end
282-
new_prob = prob.prob_func(_prob,i,iter)
283-
x = prob.output_func(solve(new_prob,alg;alias_jumps=true,kwargs...),i)
284-
if !(typeof(x) <: Tuple)
285-
@warn("output_func should return (out,rerun). See docs for updated details")
286-
_x = (x,false)
287-
else
288-
_x = x
289-
end
290-
rerun = _x[2]
291-
end
292-
_x[1]
293-
end
294-
295-
#batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
296-
batch_data = Vector{Any}(undef,length(II))
297-
298-
let
299-
if length(II) == 1 || Threads.nthreads() == 1
300-
for batch_idx in axes(batch_data, 1)
301-
batch_data[batch_idx] = multithreaded_batch(batch_idx)
302-
end
303-
else
304-
Threads.@threads for batch_idx in axes(batch_data, 1)
305-
batch_data[batch_idx] = multithreaded_batch(batch_idx)
306-
end
307-
end
308-
end
309-
tighten_container_eltype(batch_data)
310-
end
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using OrdinaryDiffEq
2+
function f(du,u,p,t)
3+
du[1] = 1.01*u[1]
4+
end
5+
u0 = [0.0, 0.0]
6+
tspan = (0.0,1.0)
7+
prob = ODEProblem(f,u0,tspan)
8+
n = 100
9+
10+
initial_conditions = range(0, stop=1, length=n)
11+
function prob_func(prob,i,repeat)
12+
prob.u0[1] = initial_conditions[i]
13+
prob
14+
end
15+
ensemble_prob = EnsembleProblem(prob,prob_func=prob_func)
16+
sim_1 = solve(ensemble_prob,Tsit5()
17+
,EnsembleThreads(),
18+
trajectories=100)
19+
sim_2 = solve(ensemble_prob,Tsit5()
20+
,EnsembleDistributed(),
21+
trajectories=100)
22+
ss_sol_1 = hcat(collect(EnsembleAnalysis.get_timepoint(sim_1,0))...);
23+
ss_sol_2 = hcat(collect(EnsembleAnalysis.get_timepoint(sim_2,0))...);
24+
25+
ss_sol_1[1,:] == initial_conditions
26+
ss_sol_2[1,:] == initial_conditions

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ if !is_APPVEYOR && GROUP == "Downstream"
3838
@time @safetestset "Null Parameters" begin include("downstream/null_params_test.jl") end
3939
@time @safetestset "Ensemble Simulations" begin include("downstream/ensemble.jl") end
4040
@time @safetestset "Ensemble Analysis" begin include("downstream/ensemble_analysis.jl") end
41+
@time @safetestset "Ensemble Thread Safety" begin include("downstream/ensemble_thread_safety.jl") end
4142
@time @safetestset "Inference Tests" begin include("downstream/inference.jl") end
4243
@time @safetestset "Default linsolve with structure" begin include("downstream/default_linsolve_structure.jl") end
4344
@time @safetestset "Callback Merging Tests" begin include("downstream/callback_merging.jl") end

0 commit comments

Comments
 (0)